def fun_exc(x):
    """Convert the event exceptionality of a trace into an SNR-like score.

    Parameters
    ----------
    x : tuple
        ``(fluo, param)``. ``fluo`` is a fluorescence trace (made at least 2D
        with ``np.atleast_2d``). ``param`` is a mapping read for the keys
        ``'fr'`` and ``'decay_time'`` (presumably frame rate and indicator
        decay time -- confirm against callers).

    Returns
    -------
    np.ndarray
        ``-norm.ppf(exp(erfc / N_samples))``, where ``erfc`` is the second
        output of ``compute_event_exceptionality`` and
        ``N_samples = ceil(fr * decay_time)``.
    """
    from scipy.stats import norm
    from caiman.components_evaluation import compute_event_exceptionality

    fluo, param = x
    # np.int was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin int is the documented replacement.
    N_samples = np.ceil(param['fr'] * param['decay_time']).astype(int)
    ev = compute_event_exceptionality(np.atleast_2d(fluo), N=N_samples)
    # ev[1] is the (log) erfc output; normalize per sample before mapping
    # through the inverse normal CDF.
    return -norm.ppf(np.exp(np.array(ev[1]) / N_samples))
# Example #2
# 0
def update_num_components(t,
                          sv,
                          Ab,
                          Cf,
                          Yres_buf,
                          Y_buf,
                          rho_buf,
                          dims,
                          gSig,
                          gSiz,
                          ind_A,
                          CY,
                          CC,
                          groups,
                          oases,
                          gnb=1,
                          rval_thr=0.875,
                          bSiz=3,
                          robust_std=False,
                          N_samples_exceptionality=5,
                          remove_baseline=True,
                          thresh_fitness_delta=-20,
                          thresh_fitness_raw=-20,
                          thresh_overlap=0.5,
                          batch_update_suff_stat=False,
                          sn=None,
                          g=None,
                          lam=0,
                          thresh_s_min=None,
                          s_min=None,
                          Ab_dense=None,
                          max_num_added=1):
    """Greedily search the residual buffer for new components and add them.

    Candidates are seeded at the pixel with the largest value of a working
    copy of ``sv``. ``sv`` is slid by one frame of ``rho_buf`` on entry.
    A candidate footprint is refined with ``rank1nmf`` on a ``gSiz``-sized
    patch around the seed, and the candidate is accepted only if:

    * the refined footprint correlates with the patch mean above
      ``rval_thr``;
    * no existing component whose footprint overlaps it by more than
      ``thresh_overlap`` has a trace correlating above 0.8 with it;
    * its trace passes the event-exceptionality test, either on the
      differenced trace (``thresh_fitness_delta``) or on the
      baseline-removed trace (``thresh_fitness_raw``).

    The search stops at the first rejected candidate or after
    ``max_num_added`` accepted ones.

    For each accepted component the function:

    * appends it to ``Ab`` (via ``csc_append``) and an OASIS object to
      ``oases``;
    * extends ``Cf``, ``CC``, ``CY``, ``ind_A`` and ``groups``;
    * subtracts it from ``Yres_buf``;
    * lowers ``rho_buf`` / ``sv`` by its blurred energy.

    If ``s_min`` is None it is derived per candidate: 0 when
    ``thresh_s_min`` is None, otherwise
    ``thresh_s_min * sqrt(ain**2 . sn**2) * sqrt(1 - sum(g))``.

    Returns
    -------
    Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups
    """
    gHalf = np.array(gSiz) // 2
    useOASIS = False  # whether to use faster OASIS for cell detection

    # number of total components (including gnb background components)
    M = np.shape(Ab)[-1]

    # slide the window: drop the oldest frame of rho_buf, add the newest
    sv -= rho_buf.get_first()
    sv += rho_buf.get_last_frames(1).squeeze()

    # working copy of sv used only to pick seed pixels
    sv_ = sv.copy()

    num_added = 0
    while num_added < max_num_added:
        ind = np.argmax(sv_)
        ij = np.unravel_index(ind, dims)
        # patch bounds around the seed, clipped to the field of view
        ijSig = [[max(ij[0] - gHalf[0], 0),
                  min(ij[0] + gHalf[0] + 1, dims[0])],
                 [max(ij[1] - gHalf[1], 0),
                  min(ij[1] + gHalf[1] + 1, dims[1])]]
        indeces = np.ravel_multi_index(
            np.ix_(np.arange(ijSig[0][0], ijSig[0][1]),
                   np.arange(ijSig[1][0], ijSig[1][1])),
            dims,
            order='F').ravel()

        Ypx = Yres_buf.T[indeces, :]

        ain = np.maximum(np.mean(Ypx, 1), 0)
        na = ain.dot(ain)
        if not na:
            break
        ain /= sqrt(na)

        # rank1nmf expects and returns a normalized ain
        ain, cin, cin_res = rank1nmf(Ypx, ain)

        rval = corr(ain.copy(), np.mean(Ypx, -1))
        # written as `not >` so that a NaN correlation is also rejected
        if not rval > rval_thr:
            break

        # dense footprint; the sparse version is built only if accepted
        Ain = np.zeros((np.prod(dims), 1), dtype=np.float32)
        Ain[indeces, :] = ain[:, None]

        cin_circ = cin.get_ordered()

        # reject duplicates of overlapping existing components
        if Ab_dense is None:
            ff = np.where(
                (Ab.T.dot(Ain).T > thresh_overlap)[:, gnb:])[1] + gnb
        else:
            ff = np.where(
                Ab_dense[indeces, gnb:].T.dot(ain).T > thresh_overlap)[0] + gnb
        if ff.size > 0:
            cc = [corr(cin_circ.copy(), cins) for cins in Cf[ff, :]]
            if np.any(np.array(cc) > .8):
                break

        # candidate-specific s_min. Computed per candidate so that a value
        # derived from one candidate's noise is not reused for later ones.
        # The formula comes from running OASIS with s_min=0 and lambda=0 on
        # Gaussian noise: k * sigma * sqrt(1-sum(gamma)) ~ RMS (k=1),
        # 95th (k=2), 99.7th (k=3) percentile of non-zero spike sizes.
        if s_min is None:
            s_min_cand = 0 if thresh_s_min is None else thresh_s_min * \
                sqrt((ain**2).dot(sn[indeces]**2)) * sqrt(1 - np.sum(g))
        else:
            s_min_cand = s_min

        cin_res = cin_res.get_ordered()
        if useOASIS:
            oas = oasis.OASIS(g=g,
                              s_min=s_min_cand,
                              num_empty_samples=t + 1 - len(cin_res))
            for yt in cin_res:
                oas.fit_next(yt)
            accepted = oas.get_l_of_last_pool() <= t
        else:
            fitness_delta, erfc_delta, std_rr, _ = compute_event_exceptionality(
                np.diff(cin_res)[None, :],
                robust_std=robust_std,
                N=N_samples_exceptionality)
            if remove_baseline:
                num_samps_bl = min(len(cin_res) // 5, 800)
                bl = scipy.ndimage.percentile_filter(cin_res,
                                                     8,
                                                     size=num_samps_bl)
            else:
                bl = 0
            fitness_raw, erfc_raw, std_rr, _ = compute_event_exceptionality(
                (cin_res - bl)[None, :],
                robust_std=robust_std,
                N=N_samples_exceptionality)
            accepted = (fitness_delta < thresh_fitness_delta) or (
                fitness_raw < thresh_fitness_raw)

        if not accepted:
            break

        num_added += 1
        if not useOASIS:
            # lambda from Selesnick's 3*sigma*|K| rule, with the noise level
            # estimated on the candidate trace (std_rr)
            sn_ = std_rr
            oas = oasis.OASIS(
                np.ravel(g)[0],
                3 * sn_ / (sqrt(1 - g**2) if np.size(g) == 1 else sqrt(
                    (1 + g[1]) * ((1 - g[1])**2 - g[0]**2) /
                    (1 - g[1]))) if s_min_cand == 0 else 0,
                s_min_cand,
                num_empty_samples=t + 1 - len(cin_res),
                g2=0 if np.size(g) == 1 else g[1])
            for yt in cin_res:
                oas.fit_next(yt)
        oases.append(oas)

        Ain_csc = scipy.sparse.csc_matrix(
            (ain, (indeces, [0] * len(indeces))), (np.prod(dims), 1),
            dtype=np.float32)

        if Ab_dense is None:
            groups = update_order(Ab, Ain, groups)[0]
        else:
            groups = update_order(Ab_dense[indeces], ain, groups)[0]
            # keep the dense copy in sync, as csc_append does for Ab, so that
            # later candidates are overlap-checked against this component
            Ab_dense = np.hstack((Ab_dense, Ain))
        csc_append(Ab, Ain_csc)  # faster version of scipy.sparse.hstack
        ind_A.append(Ab.indices[Ab.indptr[M]:Ab.indptr[M + 1]])

        # update sufficient statistics with the new component
        tt = t * 1.
        # only CY[m, ind_pixels] is ever accessed, so update just those
        CY[M, indeces] = cin.dot(Y_buf[:, indeces]) / tt
        CC1 = np.hstack([CC, Cf.dot(cin_circ / tt)[:, None]])
        CC2 = np.hstack([(Cf.dot(cin_circ)).T,
                         cin_circ.dot(cin_circ)]) / tt
        CC = np.vstack([CC1, CC2])
        Cf = np.vstack([Cf, cin_circ])

        M = M + 1

        # remove the component from the residual buffer
        Yres_buf[:, indeces] -= np.outer(cin, ain)

        # blur only around the component; vb is a view of Ain, so Ain is
        # modified here (after it was last used above)
        vb = np.reshape(Ain, dims, order='F')
        slices = tuple(
            slice(max(0, ijs[0] - 2 * sg), min(d, ijs[1] + 2 * sg))
            for ijs, sg, d in zip(ijSig, gSig, dims))  # is 2 enough?
        vb[slices] = imblur(vb[slices], sig=gSig, siz=gSiz, nDimBlur=2)
        vb = vb.ravel()

        ind_vb = np.ravel_multi_index(
            np.ix_(*[np.arange(s.start, s.stop) for s in slices]),
            dims).ravel()

        # subtract the component's blurred energy from rho_buf and sv
        updt_res = (vb[None, ind_vb].T**2).dot(cin[None, :]**2).T
        rho_buf[:, ind_vb] -= updt_res
        updt_res_sum = np.sum(updt_res, 0)
        sv[ind_vb] -= updt_res_sum
        sv_[ind_vb] -= updt_res_sum

    return Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups
#%% WEIGHTED SUFF STAT
if ploton:
    from caiman.components_evaluation import compute_event_exceptionality
    from scipy.stats import norm

    min_SNR = 2.5
    # number of timesteps to consider when testing new neuron candidates.
    # Builtin int replaces np.int, which was removed in NumPy 1.24.
    N_samples = np.ceil(fr * decay_time).astype(int)
    fitness, erf, noi, what = compute_event_exceptionality(
        C + cnm2.noisyC[cnm2.gnb:cnm2.M, t - t // epochs:t], N=N_samples)
    COMP_SNR = -norm.ppf(np.exp(erf / N_samples))
    # prepend unit weights with the shape of f
    # (presumably the background traces -- confirm)
    COMP_SNR = np.vstack([np.ones_like(f), COMP_SNR])
    COMP_SNR = np.clip(COMP_SNR, a_min=0, a_max=100)

    Cf = cnm2.C_on[:cnm2.M, t - t // epochs:t]

    # weight each trace by its clipped SNR before forming the statistics
    Cf_ = Cf * COMP_SNR
    CC_ = Cf_.dot(Cf.T)
    CY_ = Cf_.dot([
        (cv2.resize(yy, dims[::-1]).reshape(-1, order='F') - img_min) /
        img_norm.reshape(-1, order='F') for yy in Y_
    ])
# Example #4
# 0
def update_num_components(t,
                          sv,
                          Ab,
                          Cf,
                          Yres_buf,
                          Y_buf,
                          rho_buf,
                          dims,
                          gSig,
                          gSiz,
                          ind_A,
                          CY,
                          CC,
                          groups,
                          oases,
                          gnb=1,
                          rval_thr=0.875,
                          bSiz=3,
                          robust_std=False,
                          N_samples_exceptionality=5,
                          remove_baseline=True,
                          thresh_fitness_delta=-80,
                          thresh_fitness_raw=-20,
                          thresh_overlap=0.25,
                          batch_update_suff_stat=False,
                          sn=None,
                          g=None,
                          thresh_s_min=None,
                          s_min=None,
                          Ab_dense=None,
                          max_num_added=1,
                          min_num_trial=1,
                          loaded_model=None,
                          thresh_CNN_noisy=0.99,
                          sniper_mode=False,
                          use_peak_max=False,
                          test_both=False):
    """
    Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests

    Candidates (footprint ``ain``, trace ``cin``, residual trace ``cin_res``)
    are proposed by ``get_candidate_components``. A candidate is rejected as
    a duplicate when its trace correlates above 0.25 with an existing
    component whose footprint overlaps it by more than ``thresh_overlap``.
    Otherwise it is accepted if the event-exceptionality test passes on the
    differenced trace (``thresh_fitness_delta``) or on the baseline-removed
    trace (``thresh_fitness_raw``).

    For each accepted component the function:

    * appends it to ``Ab`` (via ``csc_append``), ``Cf``, ``CC``, ``CY`` and
      ``ind_A``, and (if ``oases`` is not None) an OASIS object to ``oases``;
    * subtracts it from ``Yres_buf``;
    * recomputes ``rho_buf`` and ``sv`` on the component's patch.

    ``s_min``: None means 0. A negative value is interpreted relative to each
    candidate's noise level: ``-s_min * sqrt(ain**2 . sn**2) * sqrt(1 - sum(g))``.

    Returns
    -------
    Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups, ind_new,
    ind_new_all, sv, cnn_pos
        ``sv`` appears twice (kept for backward compatibility).
        ``ind_new`` collects the patch bounds ``ijSig`` of accepted
        components. ``ind_new_all`` is the ``ijsig_all`` output of
        ``get_candidate_components`` (presumably the bounds of all
        candidates).
    """
    ind_new = []
    gHalf = np.array(gSiz) // 2

    # number of total components (including background)
    M = np.shape(Ab)[-1]

    # update variance of residual buffer: drop oldest frame, add newest
    sv -= rho_buf.get_first()
    sv += rho_buf.get_last_frames(1).squeeze()
    sv = np.maximum(sv, 0)

    Ains, Cins, Cins_res, inds, ijsig_all, cnn_pos, local_max = get_candidate_components(
        sv,
        dims,
        Yres_buf=Yres_buf,
        min_num_trial=min_num_trial,
        gSig=gSig,
        gHalf=gHalf,
        sniper_mode=sniper_mode,
        rval_thr=rval_thr,
        patch_size=50,
        loaded_model=loaded_model,
        thresh_CNN_noisy=thresh_CNN_noisy,
        use_peak_max=use_peak_max,
        test_both=test_both)

    ind_new_all = ijsig_all
    useOASIS = False  # whether to use faster OASIS for cell detection

    for ind, ain, cin, cin_res in zip(inds, Ains, Cins, Cins_res):
        ij = np.unravel_index(ind, dims)
        # patch bounds around the candidate, clipped to the field of view
        ijSig = [[max(i - temp_g, 0), min(i + temp_g + 1, d)]
                 for i, temp_g, d in zip(ij, gHalf, dims)]
        indeces = np.ravel_multi_index(
            np.ix_(*[np.arange(lo, hi) for lo, hi in ijSig]),
            dims,
            order='F').ravel()

        # dense footprint; the sparse version is built only if accepted
        Ain = np.zeros((np.prod(dims), 1), dtype=np.float32)
        Ain[indeces, :] = ain[:, None]

        cin_circ = cin.get_ordered()
        accepted = True  # candidate has not been rejected yet

        # reject duplicates of overlapping existing components
        if Ab_dense is None:
            ff = np.where((Ab.T.dot(Ain).T > thresh_overlap)[:, gnb:])[1] + gnb
        else:
            ff = np.where(
                Ab_dense[indeces, gnb:].T.dot(ain).T > thresh_overlap)[0] + gnb
        if ff.size > 0:
            cc = [corr(cin_circ.copy(), cins) for cins in Cf[ff, :]]
            if np.any(np.array(cc) > .25):
                accepted = False  # reject component as duplicate

        # candidate-specific s_min. Computed per candidate so that a relative
        # (negative) s_min is not frozen to the first candidate's noise.
        # The formula comes from running OASIS with s_min=0 and lambda=0 on
        # Gaussian noise: k * sigma * sqrt(1-sum(gamma)) ~ RMS (k=1),
        # 95th (k=2), 99.7th (k=3) percentile of non-zero spike sizes.
        if s_min is None:
            s_min_cand = 0
        elif s_min < 0:
            s_min_cand = -s_min * sqrt(
                (ain**2).dot(sn[indeces]**2)) * sqrt(1 - np.sum(g))
        else:
            s_min_cand = s_min

        cin_res = cin_res.get_ordered()
        if accepted:
            if useOASIS:
                oas = oasis.OASIS(g=g,
                                  s_min=s_min_cand,
                                  num_empty_samples=t + 1 - len(cin_res))
                for yt in cin_res:
                    oas.fit_next(yt)
                accepted = oas.get_l_of_last_pool() <= t
            else:
                fitness_delta, erfc_delta, std_rr, _ = compute_event_exceptionality(
                    np.diff(cin_res)[None, :],
                    robust_std=robust_std,
                    N=N_samples_exceptionality)
                if remove_baseline:
                    num_samps_bl = min(len(cin_res) // 5, 800)
                    bl = scipy.ndimage.percentile_filter(cin_res,
                                                         8,
                                                         size=num_samps_bl)
                else:
                    bl = 0
                fitness_raw, erfc_raw, std_rr, _ = compute_event_exceptionality(
                    (cin_res - bl)[None, :],
                    robust_std=robust_std,
                    N=N_samples_exceptionality)
                accepted = (fitness_delta < thresh_fitness_delta) or (
                    fitness_raw < thresh_fitness_raw)

        if not accepted:
            continue

        ind_new.append(ijSig)

        if oases is not None:
            if not useOASIS:
                # lambda from Selesnick's 3*sigma*|K| rule, with the noise
                # level estimated on the candidate trace (std_rr)
                sn_ = std_rr
                oas = oasis.OASIS(
                    np.ravel(g)[0],
                    3 * sn_ / (sqrt(1 - g**2) if np.size(g) == 1 else sqrt(
                        (1 + g[1]) * ((1 - g[1])**2 - g[0]**2) /
                        (1 - g[1]))) if s_min_cand == 0 else 0,
                    s_min_cand,
                    num_empty_samples=t + 1 - len(cin_res),
                    g2=0 if np.size(g) == 1 else g[1])
                for yt in cin_res:
                    oas.fit_next(yt)
            oases.append(oas)

        Ain_csc = scipy.sparse.csc_matrix(
            (ain, (indeces, [0] * len(indeces))), (np.prod(dims), 1),
            dtype=np.float32)
        if Ab_dense is None:
            groups = update_order(Ab, Ain, groups)[0]
        else:
            groups = update_order(Ab_dense[indeces], ain, groups)[0]
            Ab_dense = np.hstack((Ab_dense, Ain))
        # faster version of scipy.sparse.hstack
        csc_append(Ab, Ain_csc)
        ind_A.append(Ab.indices[Ab.indptr[M]:Ab.indptr[M + 1]])

        # update sufficient statistics with the new component
        tt = t * 1.
        CY[M, indeces] = cin.dot(Y_buf[:, indeces]) / tt
        CC1 = np.hstack([CC, Cf.dot(cin_circ / tt)[:, None]])
        CC2 = np.hstack([(Cf.dot(cin_circ)).T,
                         cin_circ.dot(cin_circ)]) / tt
        CC = np.vstack([CC1, CC2])
        Cf = np.vstack([Cf, cin_circ])

        M = M + 1

        # remove the component from the residual buffer
        Yres_buf[:, indeces] -= np.outer(cin, ain)

        # Recompute rho_buf/sv on the component's patch. Blur a padded region
        # (2*gHalf margin) so the patch interior is filtered with its real
        # neighbourhood, then crop back to the patch. gHalf (not gSiz // 2)
        # also works when gSiz is a tuple or list.
        slices = tuple(
            slice(max(0, ijs[0] - 2 * sg), min(d, ijs[1] + 2 * sg))
            for ijs, sg, d in zip(ijSig, gHalf, dims))  # is 2 enough?
        slice_within = tuple(
            slice(ijs[0] - sl.start, ijs[1] - sl.start)
            for ijs, sl in zip(ijSig, slices))

        ind_vb = np.ravel_multi_index(
            np.ix_(*[np.arange(lo, hi) for lo, hi in ijSig]),
            dims,
            order='C').ravel()

        vb_buf = np.stack([
            imblur(np.maximum(0, vb.reshape(dims, order='F')[slices]),
                   sig=gSig,
                   siz=gSiz,
                   nDimBlur=len(dims))[slice_within].ravel()
            for vb in Yres_buf
        ])

        rho_buf[:, ind_vb] = vb_buf**2
        sv[ind_vb] = np.sum(rho_buf[:, ind_vb], 0)

    return Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups, ind_new, ind_new_all, sv, cnn_pos
# Example #5
# 0
def update_num_components(t,
                          sv,
                          Ab,
                          Cf,
                          Yres_buf,
                          Y_buf,
                          rho_buf,
                          dims,
                          gSig,
                          gSiz,
                          ind_A,
                          CY,
                          CC,
                          groups,
                          oases,
                          gnb=1,
                          rval_thr=0.875,
                          bSiz=3,
                          robust_std=False,
                          N_samples_exceptionality=5,
                          remove_baseline=True,
                          thresh_fitness_delta=-80,
                          thresh_fitness_raw=-20,
                          thresh_overlap=0.25,
                          batch_update_suff_stat=False,
                          sn=None,
                          g=None,
                          thresh_s_min=None,
                          s_min=None,
                          Ab_dense=None,
                          max_num_added=1,
                          min_num_trial=1):
    """
    Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests.

    Each loop iteration takes the pixel with the largest value in a working
    copy ``sv_`` of ``sv``. It then fits a rank-1 NMF (``rank1nmf``) to the
    residual traces of the neighbourhood of half-size ``gHalf = gSiz // 2``
    around that pixel. A candidate is accepted when all of the following hold:

    * the correlation between the fitted footprint and the mean residual
      patch exceeds ``rval_thr``;
    * it is not rejected for overlapping an existing component by more than
      ``thresh_overlap``;
    * its trace passes the event-exceptionality test
      (``thresh_fitness_delta`` on the differenced trace or
      ``thresh_fitness_raw`` on the baseline-removed trace), or the OASIS
      test when ``useOASIS`` is enabled.

    Accepted components are appended to ``Ab``, ``ind_A``, ``Cf``, ``CC``,
    ``CY``, ``groups`` (and ``oases`` when it is not None). Their contribution
    is then removed from ``Yres_buf``, ``rho_buf``, ``sv`` and ``sv_``.

    The loop ends in one of three ways: after ``max_num_added`` acceptances;
    when the clipped mean patch has zero norm; or at the first rejection once
    ``min_num_trial`` candidates have been tried. Earlier rejections zero the
    neighbourhood in ``sv_`` so the next argmax moves elsewhere.

    Returns
    -------
    Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups
        Updated inputs. ``sv``, ``CY``, ``Yres_buf``, ``rho_buf``, ``ind_A``
        and ``oases`` are modified in place. ``Ab`` is extended via
        ``csc_append``, presumably in place (TODO confirm). ``Cf``, ``CC``
        and ``groups`` are rebuilt and must be taken from the return value.
    """

    order_rvl = 'C'
    gHalf = np.array(gSiz) // 2

    # number of total components (including background)
    M = np.shape(Ab)[-1]
    N = M - gnb  # number of components (without background)

    # sv_ (working copy of sv used for the argmax search) is created on the
    # first loop iteration
    first = True

    # slide the window: drop the oldest frame of rho_buf from sv and add the
    # newest one (rho_buf is presumably a ring buffer of per-frame energies --
    # TODO confirm)
    sv -= rho_buf.get_first()
    # update variance of residual buffer
    sv += rho_buf.get_last_frames(1).squeeze()

    num_added = 0
    cnt = 0  # number of candidates examined so far
    while num_added < max_num_added:
        cnt += 1
        if first:
            sv_ = sv.copy()  # np.sum(rho_buf,0)
            first = False

        # candidate center: pixel with the largest remaining value; sv_ is
        # indexed in C order (order_rvl)
        ind = np.argmax(sv_)
        ij = np.unravel_index(ind, dims, order=order_rvl)
        # ijSig = [[np.maximum(ij[c] - gHalf[c], 0), np.minimum(ij[c] + gHalf[c] + 1, dims[c])]
        #          for c in range(len(ij))]
        # better than above expensive call of numpy and loop creation

        #        ijSig = [[max(ij[0] - gHalf[0], 0), min(ij[0] + gHalf[0] + 1, dims[0])],
        #                 [max(ij[1] - gHalf[1], 0), min(ij[1] + gHalf[1] + 1, dims[1])]]

        # [start, stop) bounds of the neighbourhood, clipped to the field of view
        ijSig = [[max(i - g, 0), min(i + g + 1, d)]
                 for i, g, d in zip(ij, gHalf, dims)]

        # xySig = np.meshgrid(*[np.arange(s[0], s[1]) for s in ijSig], indexing='xy')
        # arr = np.array([np.reshape(s, (1, np.size(s)), order='F').squeeze()
        #                 for s in xySig], dtype=np.int)
        # indeces = np.ravel_multi_index(arr, dims, order='F')

        #        indeces = np.ravel_multi_index(np.ix_(np.arange(ijSig[0][0], ijSig[0][1]),
        #                                              np.arange(ijSig[1][0], ijSig[1][1])),
        #                                       dims, order='F').ravel(order=order_rvl)

        # F-order flat indices of the neighbourhood: the pixel layout used
        # for Yres_buf, Ain and Ab below
        indeces = np.ravel_multi_index(
            np.ix_(*[np.arange(ij[0], ij[1]) for ij in ijSig]),
            dims,
            order='F').ravel(order=order_rvl)

        # C-order flat indices of the same neighbourhood: the layout of sv_
        indeces_ = np.ravel_multi_index(
            np.ix_(*[np.arange(ij[0], ij[1]) for ij in ijSig]),
            dims,
            order='C').ravel(order=order_rvl)

        #        indeces_ = np.ravel_multi_index(np.ix_(np.arange(ijSig[0][0], ijSig[0][1]),
        #                                               np.arange(ijSig[1][0], ijSig[1][1])),
        #                                        dims, order='C').ravel(order=order_rvl)

        # residual traces of the neighbourhood pixels (pixels x frames)
        Ypx = Yres_buf.T[indeces, :]

        # initial footprint: temporal mean of the patch clipped at 0; stop the
        # search when it is identically zero
        ain = np.maximum(np.mean(Ypx, 1), 0)
        na = ain.dot(ain)
        if not na:
            break

        ain /= sqrt(na)

        #        new_res = sv_.copy()
        #        new_res[ np.ravel_multi_index(arr, dims, order='C')] = 10000

        # expects and returns normalized ain
        ain, cin, cin_res = rank1nmf(Ypx, ain)
        # correlation coefficient
        rval = corr(ain.copy(), np.mean(Ypx, -1))

        if rval > rval_thr:
            # use sparse Ain only later iff it is actually added to Ab
            Ain = np.zeros((np.prod(dims), 1), dtype=np.float32)
            Ain[indeces, :] = ain[:, None]

            cin_circ = cin.get_ordered()

            #        indeces_good = (Ain[indeces]>0.01).nonzero()[0]

            useOASIS = False  # whether to use faster OASIS for cell detection
            foo = True  # flag indicating new component has not been rejected yet

            # existing non-background components whose footprint overlaps the
            # candidate by more than thresh_overlap
            if Ab_dense is None:
                ff = np.where(
                    (Ab.T.dot(Ain).T > thresh_overlap)[:, gnb:])[1] + gnb
            else:
                ff = np.where(
                    Ab_dense[indeces,
                             gnb:].T.dot(ain).T > thresh_overlap)[0] + gnb
            if ff.size > 0:
                # NOTE(review): foo is set to False before the correlation
                # test, so any overlapping component rejects the candidate and
                # the `... and foo` branch below can never execute. Confirm
                # whether rejection should depend on the trace correlation.
                foo = False
                cc = [corr(cin_circ.copy(), cins) for cins in Cf[ff, :]]
                if np.any(np.array(cc) > .25) and foo:
                    #                    repeat = False
                    # vb = imblur(np.reshape(Ain, dims, order='F'),
                    #             sig=gSig, siz=gSiz, nDimBlur=2)
                    # restrict blurring to region where component is located
                    # NOTE(review): this reshapes Ain with order='C' although
                    # Ain was filled using F-order indices; the acceptance path
                    # below uses order='F'.
                    vb = np.reshape(Ain, dims, order='C')
                    slices = tuple(
                        slice(max(0, ijs[0] - 2 * sg),
                              min(d, ijs[1] + 2 * sg)) for ijs, sg, d in zip(
                                  ijSig, gSig, dims))  # is 2 enough?
                    vb[slices] = imblur(vb[slices],
                                        sig=gSig,
                                        siz=gSiz,
                                        nDimBlur=len(dims))
                    sv_ -= (vb.ravel(order=order_rvl)**2) * cin.dot(cin)
                    foo = False  # reject component as duplicate
#                    pl.imshow(np.reshape(sv,dims));pl.pause(0.001)
#  print('Overlap at step' + str(t) + ' ' + str(cc))
# break

# use thresh_s_min * noise estimate * sqrt(1-sum(gamma))
# NOTE(review): s_min is only derived when it is None, so the value computed
# from the first candidate reaching this point is reused for later candidates
# of the same call.
            if s_min is None:
                # the formula has been obtained by running OASIS with s_min=0 and lambda=0 on Gaussin noise.
                # e.g. 1 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the root mean square (non-zero) spike size, sqrt(<s^2>)
                #      2 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 95% percentile of (non-zero) spike sizes
                #      3 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 99.7% percentile of (non-zero) spike sizes
                s_min = 0 if thresh_s_min is None else thresh_s_min * \
                    sqrt((ain**2).dot(sn[indeces]**2)) * sqrt(1 - np.sum(g))

            cin_res = cin_res.get_ordered()
            if foo:
                if useOASIS:
                    oas = oasis.OASIS(g=g,
                                      s_min=s_min,
                                      num_empty_samples=t + 1 - len(cin_res))
                    for yt in cin_res:
                        oas.fit_next(yt)
                    foo = oas.get_l_of_last_pool() <= t
                else:
                    # exceptionality of the differenced trace and of the
                    # (optionally baseline-subtracted) raw trace; the
                    # candidate passes if either is below its threshold
                    fitness_delta, erfc_delta, std_rr, _ = compute_event_exceptionality(
                        np.diff(cin_res)[None, :],
                        robust_std=robust_std,
                        N=N_samples_exceptionality)
                    if remove_baseline:
                        # baseline: running 8th percentile over a window of
                        # len/5 frames (at most 800)
                        num_samps_bl = min(len(cin_res) // 5, 800)
                        bl = scipy.ndimage.percentile_filter(cin_res,
                                                             8,
                                                             size=num_samps_bl)
                    else:
                        bl = 0
                    fitness_raw, erfc_raw, std_rr, _ = compute_event_exceptionality(
                        (cin_res - bl)[None, :],
                        robust_std=robust_std,
                        N=N_samples_exceptionality)
                    foo = (fitness_delta < thresh_fitness_delta) or (
                        fitness_raw < thresh_fitness_raw)

            if foo:
                # print('adding component' + str(N + 1) + ' at timestep ' + str(t))
                num_added += 1
                #                ind_a = uniform_filter(np.reshape(Ain.toarray(), dims, order='F'), size=bSiz)
                #                ind_a = np.reshape(ind_a > 1e-10, (np.prod(dims),), order='F')
                #                indeces_good = np.where(ind_a)[0]#np.where(determine_search_location(Ain,dims))[0]
                if oases is not None:
                    if not useOASIS:
                        # lambda from Selesnick's 3*sigma*|K| rule
                        # use noise estimate from init batch or use std_rr?
                        #                    sn_ = sqrt((ain**2).dot(sn[indeces]**2)) / sqrt(1 - g**2)
                        sn_ = std_rr
                        oas = oasis.OASIS(
                            np.ravel(g)[0],
                            3 * sn_ /
                            (sqrt(1 - g**2) if np.size(g) == 1 else sqrt(
                                (1 + g[1]) * ((1 - g[1])**2 - g[0]**2) /
                                (1 - g[1]))) if s_min == 0 else 0,
                            s_min,
                            num_empty_samples=t + 1 - len(cin_res),
                            g2=0 if np.size(g) == 1 else g[1])
                        for yt in cin_res:
                            oas.fit_next(yt)

                    oases.append(oas)

                Ain_csc = scipy.sparse.csc_matrix(
                    (ain, (indeces, [0] * len(indeces))), (np.prod(dims), 1),
                    dtype=np.float32)

                if Ab_dense is None:
                    groups = update_order(Ab, Ain, groups)[0]
                else:
                    groups = update_order(Ab_dense[indeces], ain, groups)[0]
                # faster version of scipy.sparse.hstack
                csc_append(Ab, Ain_csc)
                ind_A.append(Ab.indices[Ab.indptr[M]:Ab.indptr[M + 1]])

                tt = t * 1.
                #                if batch_update_suff_stat and Y_buf.cur<len(Y_buf)-1:
                #                   Y_buf_ = Y_buf[Y_buf.cur+1:,:]
                #                   cin_ = cin[Y_buf.cur+1:]
                #                   n_fr_ = len(cin_)
                #                   cin_circ_= cin_circ[-n_fr_:]
                #                   Cf_ = Cf[:,-n_fr_:]
                #                else:
                Y_buf_ = Y_buf
                cin_ = cin
                Cf_ = Cf
                cin_circ_ = cin_circ

                #                CY[M, :] = Y_buf_.T.dot(cin_)[None, :] / tt
                # much faster: exploit that we only access CY[m, ind_pixels],
                # hence update only these
                CY[M, indeces] = cin_.dot(Y_buf_[:, indeces]) / tt

                # grow CC by one row and one column for the new component,
                # both scaled by 1/t
                # preallocate memory for speed up?
                CC1 = np.hstack([CC, Cf_.dot(cin_circ_ / tt)[:, None]])
                CC2 = np.hstack([(Cf_.dot(cin_circ_)).T,
                                 cin_circ_.dot(cin_circ_)]) / tt
                CC = np.vstack([CC1, CC2])
                Cf = np.vstack([Cf, cin_circ])

                N = N + 1
                M = M + 1

                # remove the new component from the residual buffer
                Yres_buf[:, indeces] -= np.outer(cin, ain)
                # vb = imblur(np.reshape(Ain, dims, order='F'), sig=gSig,
                #             siz=gSiz, nDimBlur=2).ravel()
                # restrict blurring to region where component is located
                # (np.reshape may return a view of Ain, in which case the blur
                # also overwrites Ain; Ain_csc has already been built)
                vb = np.reshape(Ain, dims, order='F')
                slices = tuple(
                    slice(max(0, ijs[0] - 2 * sg), min(d, ijs[1] + 2 * sg))
                    for ijs, sg, d in zip(ijSig, gSig, dims))  # is 2 enough?
                vb[slices] = imblur(vb[slices],
                                    sig=gSig,
                                    siz=gSiz,
                                    nDimBlur=len(dims))
                # flatten in C order to match the indexing of rho_buf, sv, sv_
                vb = vb.ravel(order=order_rvl)

                # ind_vb = np.where(vb)[0]
                ind_vb = np.ravel_multi_index(
                    np.ix_(*[np.arange(s.start, s.stop) for s in slices]),
                    dims,
                    order=order_rvl).ravel(order=order_rvl)

                # energy of the new component (blurred footprint^2 x trace^2)
                # removed from the residual-energy buffers
                updt_res = (vb[None, ind_vb].T**2).dot(cin[None, :]**2).T
                rho_buf[:, ind_vb] -= updt_res
                updt_res_sum = np.sum(updt_res, 0)
                sv[ind_vb] -= updt_res_sum
                sv_[ind_vb] -= updt_res_sum

            else:
                # rejected: stop once min_num_trial candidates were examined,
                # otherwise hide this neighbourhood from the next argmax
                if cnt >= min_num_trial:
                    num_added = max_num_added
                else:
                    first = False
                    sv_[indeces_] = 0

        else:
            # spatial correlation too low: same stopping rule as above
            if cnt >= min_num_trial:
                num_added = max_num_added
            else:
                first = False
                sv_[indeces_] = 0

    return Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups
Exemple #6
0
def _factors_in_box(box, centers):
    """Return the indices of the centers that lie strictly inside ``box``.

    ``box[1]``/``box[3]`` bound the first center coordinate and
    ``box[0]``/``box[2]`` bound the second one.
    """
    return [
        i for i, center in enumerate(centers)
        if box[1] < center[0] < box[3] and box[0] < center[1] < box[2]
    ]


def _closest_to_box_center(box, centers, idx_factor):
    """Return the element of ``idx_factor`` whose center is nearest the box center."""
    box_center = np.array([(box[1] + box[3]) / 2, (box[0] + box[2]) / 2])
    distances = np.linalg.norm(centers[idx_factor] - box_center, axis=1)
    return idx_factor[int(np.argmin(distances))]


def trace_correlation(data_path,
                      agonia_th,
                      select_cells=False,
                      plot_results=True,
                      denoise=False):
    '''Calculate the correlation between the mean of the Agonia Box and the CaImAn
    factor.

    Only boxes with confidence above ``agonia_th`` that contain the center of
    mass of at least one CaImAn factor are kept. For each kept box, a trace is
    built from the pixels whose median intensity lies strictly between the
    80th and 95th percentile of the box. When ``denoise`` is True, a scaled
    neuropil trace is subtracted. The trace is then correlated with the
    temporal trace of the CaImAn factor whose center lies in the box. When
    several centers lie in the box, the one closest to the box center is used.

    Parameters
    ----------
    data_path : string
        path to folder containing the data
    agonia_th : float
        threshold for detection confidence for each box
    select_cells: bool, optional
        if True get index of active cells
    plot_results: bool, optional
        if True do boxplot of correlation values for all cells, if selected_cells,
        use only active cells
    denoise : bool, optional
        if True subtract neuropil

    Returns
    -------
    cell_corr : numpy array
        corrcoef value for all cells
    idx_active : list
        if select_cells=True returns index of active cells, otherwise returns
        index of all cells
    boxes_traces : NxT ndarray
        temporal trace for the N box that has an asociated caiman factor'''

    data_name, median_projection, fnames, fname_new, results_caiman_path, boxes_path = get_files_names(
        data_path)
    # load Caiman results
    cnm = cnmf.load_CNMF(results_caiman_path)
    # centers of mass of every CaImAn spatial factor
    centers = np.empty((cnm.estimates.A.shape[1], 2))
    for i, factor in enumerate(cnm.estimates.A.T):
        centers[i] = center_of_mass(factor.toarray().reshape(
            cnm.estimates.dims, order='F'))

    # NOTE(security): pickle.load can execute arbitrary code; only load box
    # files from a trusted source.
    with open(boxes_path, 'rb') as f:
        boxes = pickle.load(f)
    # keep only cells above confidence threshold
    boxes = boxes[boxes[:, 4] > agonia_th].astype('int')

    # drop boxes that do not contain the center of any CaImAn factor
    has_factor = np.array(
        [bool(_factors_in_box(box, centers)) for box in boxes], dtype=bool)
    boxes = boxes[has_factor]

    # Load video as 3D tensor (each plane is a frame)
    Yr, dims, T = cm.load_memmap(fname_new)
    images = np.reshape(Yr.T, [T] + list(dims), order='F')
    boxes_traces = np.empty((boxes.shape[0], images.shape[0]))
    cell_corr = np.empty(len(boxes_traces))

    # without denoising the neuropil term is zero; neuropil_power needs a
    # value here, otherwise the subtraction below raises NameError
    neuropil_trace = np.zeros(T)
    neuropil_power = 0.
    if denoise:
        _, neuropil_trace, neuropil_power = substract_neuropil(
            data_path, agonia_th, 100, 80)

    for cell, box in enumerate(boxes):
        box_trace = images[:, box[1]:box[3], box[0]:box[2]]
        # percentile criterion: average only pixels whose median intensity is
        # strictly between the 80th and 95th percentile of the box
        med = np.median(box_trace, axis=0)
        in_band = np.logical_and(med > np.percentile(med, 80),
                                 med < np.percentile(med, 95))
        boxes_traces[cell] = box_trace[:, in_band].mean(axis=1)
        boxes_traces[cell] -= neuropil_trace * neuropil_power * .7

        # associated CaImAn factor: its center lies inside the box (the one
        # closest to the box center if there are several)
        best = _closest_to_box_center(box, centers,
                                      _factors_in_box(box, centers))
        cell_corr[cell] = np.corrcoef(
            [cnm.estimates.C[best], boxes_traces[cell]])[1, 0]

    if select_cells:
        # select only active cells using CaImAn criteria
        fitness, _, _, _ = compute_event_exceptionality(boxes_traces)
        idx_active = [cell for cell, fit in enumerate(fitness) if fit < -20]
    else:
        idx_active = list(range(len(boxes_traces)))

    if plot_results:
        active = set(idx_active)
        corr_toplot = [
            c for i, c in enumerate(cell_corr)
            if i in active and not np.isnan(c)
        ]
        fig, ax = plt.subplots()
        ax.boxplot(corr_toplot)
        ax.set_ylim([-1, 1])
        plt.text(1.2, 0.8, 'n_cells = {}'.format(len(corr_toplot)))

    return cell_corr, idx_active, boxes_traces

#%%
view_patches_bar(Yr, scipy.sparse.coo_matrix(A.tocsc()[:, :]), C[:, :], b, f,
                 dims[0], dims[1], YrA=noisyC - C, img=Cn)
#%% WEIGHTED SUFF STAT
ploton = True
if ploton:
    from caiman.components_evaluation import compute_event_exceptionality
    from scipy.stats import norm

    min_SNR = 2.5
    # number of timesteps to consider when testing new neuron candidates
    # (builtin int: the np.int alias was removed in NumPy 1.24)
    N_samples = np.ceil(params_movie[ind_dataset]['fr'] * params_movie[ind_dataset]['decay_time']).astype(int)
    # SNR-like score per component from the event exceptionality of C plus
    # the matching window of cnm2.noisyC
    fitness, erf, noi, what = compute_event_exceptionality(C + cnm2.noisyC[cnm2.gnb:cnm2.M, t - t // epochs:t], N=N_samples)
    COMP_SNR = -norm.ppf(np.exp(erf / N_samples))
    # rows for f (presumably the background) get weight 1; clip to [0, 100]
    COMP_SNR = np.vstack([np.ones_like(f), COMP_SNR])
    COMP_SNR = np.clip(COMP_SNR, a_min=0, a_max=100)

    Cf = cnm2.C_on[:cnm2.M, t - t // epochs:t]

    # SNR-weighted sufficient statistics
    Cf_ = Cf * COMP_SNR
#    Cf__ = Cf*np.sqrt(COMP_SNR)
    CC_ = Cf_.dot(Cf.T)
#    CC_ = Cf__.dot(Cf__.T)
    CY_ = Cf_.dot([(cv2.resize(yy, dims[::-1]).reshape(-1, order='F') - img_min) / img_norm.reshape(-1, order='F') for yy in Y_])

    Ab_, ind_A_, Ab_dense_ = cm.source_extraction.cnmf.online_cnmf.update_shapes(CY_, CC_, cnm2.Ab.copy(), cnm2.ind_A, indicator_components=None, Ab_dense=None, update_bkgrd=True, iters=55)
    #%%
    A_, b_ = Ab_[:, cnm2.gnb:], Ab_[:, :cnm2.gnb].toarray()
def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
                          dims, gSig, gSiz, ind_A, CY, CC, groups, oases, gnb=1,
                          rval_thr=0.875, bSiz=3, robust_std=False,
                          N_samples_exceptionality=5, remove_baseline=True,
                          thresh_fitness_delta=-80, thresh_fitness_raw=-20,
                          thresh_overlap=0.25, batch_update_suff_stat=False,
                          sn=None, g=None, thresh_s_min=None, s_min=None,
                          Ab_dense=None, max_num_added=1, min_num_trial=1,
                          loaded_model=None, thresh_CNN_noisy=0.99,
                          sniper_mode=False, use_peak_max=False,
                          test_both=False):
    """
    Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests.

    Candidates (footprint ``ain``, traces ``cin`` / ``cin_res``) are proposed
    by ``get_candidate_components`` from the residual-energy map ``sv``.

    A candidate is rejected as a duplicate if its trace correlates (> .25)
    with an existing component whose footprint overlaps it by more than
    ``thresh_overlap``. Otherwise it is accepted if its trace passes the
    event-exceptionality test: ``thresh_fitness_delta`` on the differenced
    trace or ``thresh_fitness_raw`` on the baseline-removed trace.

    Accepted candidates are appended to ``Ab``, ``ind_A``, ``Cf``, ``CC``,
    ``CY``, ``groups`` (and ``oases`` when it is not None). They are then
    subtracted from ``Yres_buf``, and ``rho_buf`` / ``sv`` are recomputed
    inside their neighbourhood.

    ``s_min`` handling:

    * ``None`` means 0.
    * A negative value is a relative threshold. It is converted separately
      for every candidate to
      ``-s_min * sqrt((ain**2) . sn[pixels]**2) * sqrt(1 - sum(g))``.
    * Any other value is used as is.

    Returns
    -------
    Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups, ind_new, ind_new_all, sv, cnn_pos
        ``ind_new`` holds the neighbourhood bounds (``ijSig``) of accepted
        candidates. ``ind_new_all`` and ``cnn_pos`` are the ``ijsig_all``
        and ``cnn_pos`` values returned by ``get_candidate_components``.
        ``sv`` is returned twice to keep the original return layout.
    """
    ind_new = []
    gHalf = np.array(gSiz) // 2

    # number of total components (including background)
    M = np.shape(Ab)[-1]

    # slide the window: drop the oldest frame of rho_buf and add the newest,
    # then clip negative round-off
    sv -= rho_buf.get_first()
    sv += rho_buf.get_last_frames(1).squeeze()
    sv = np.maximum(sv, 0)

    Ains, Cins, Cins_res, inds, ijsig_all, cnn_pos, local_max = get_candidate_components(
        sv, dims, Yres_buf=Yres_buf, min_num_trial=min_num_trial, gSig=gSig,
        gHalf=gHalf, sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50,
        loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
        use_peak_max=use_peak_max, test_both=test_both)

    ind_new_all = ijsig_all

    for ind, ain, cin, cin_res in zip(inds, Ains, Cins, Cins_res):
        ij = np.unravel_index(ind, dims)
        # [start, stop) bounds of the neighbourhood, clipped to the field of view
        ijSig = [[max(i - half, 0), min(i + half + 1, d)]
                 for i, half, d in zip(ij, gHalf, dims)]
        pixel_ranges = [np.arange(lo, hi) for lo, hi in ijSig]
        # F-order flat indices: pixel layout of Ab, Yres_buf and CY
        indeces = np.ravel_multi_index(
            np.ix_(*pixel_ranges), dims, order='F').ravel()

        # use sparse Ain only later iff it is actually added to Ab
        Ain = np.zeros((np.prod(dims), 1), dtype=np.float32)
        Ain[indeces, :] = ain[:, None]

        cin_circ = cin.get_ordered()
        useOASIS = False  # whether to use faster OASIS for cell detection
        accepted = True   # flag indicating new component has not been rejected yet

        # existing non-background components overlapping the candidate
        if Ab_dense is None:
            ff = np.where((Ab.T.dot(Ain).T > thresh_overlap)
                          [:, gnb:])[1] + gnb
        else:
            ff = np.where(Ab_dense[indeces, gnb:].T.dot(
                ain).T > thresh_overlap)[0] + gnb

        if ff.size > 0:
            cc = [corr(cin_circ.copy(), cins) for cins in Cf[ff, :]]
            if np.any(np.array(cc) > .25):
                accepted = False         # reject component as duplicate

        # Per-candidate spike threshold. A relative (negative) s_min is
        # converted here without overwriting the argument, so every candidate
        # uses its own noise level instead of the first candidate's.
        if s_min is None:
            s_min_cand = 0
        elif s_min < 0:
            # the formula has been obtained by running OASIS with s_min=0 and lambda=0 on Gaussian noise.
            # e.g. 1 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the root mean square (non-zero) spike size, sqrt(<s^2>)
            #      2 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 95% percentile of (non-zero) spike sizes
            #      3 * sigma * sqrt(1-sum(gamma)) corresponds roughly to the 99.7% percentile of (non-zero) spike sizes
            s_min_cand = -s_min * sqrt((ain**2).dot(sn[indeces]**2)) * sqrt(1 - np.sum(g))
        else:
            s_min_cand = s_min

        cin_res = cin_res.get_ordered()
        if accepted:
            if useOASIS:
                oas = oasis.OASIS(g=g, s_min=s_min_cand,
                                  num_empty_samples=t + 1 - len(cin_res))
                for yt in cin_res:
                    oas.fit_next(yt)
                accepted = oas.get_l_of_last_pool() <= t
            else:
                fitness_delta, erfc_delta, std_rr, _ = compute_event_exceptionality(
                    np.diff(cin_res)[None, :], robust_std=robust_std,
                    N=N_samples_exceptionality)
                if remove_baseline:
                    # baseline: running 8th percentile over len/5 frames (max 800)
                    num_samps_bl = min(len(cin_res) // 5, 800)
                    bl = scipy.ndimage.percentile_filter(
                        cin_res, 8, size=num_samps_bl)
                else:
                    bl = 0
                fitness_raw, erfc_raw, std_rr, _ = compute_event_exceptionality(
                    (cin_res - bl)[None, :], robust_std=robust_std,
                    N=N_samples_exceptionality)
                accepted = (fitness_delta < thresh_fitness_delta) or (
                    fitness_raw < thresh_fitness_raw)

        if not accepted:
            continue

        ind_new.append(ijSig)

        if oases is not None:
            if not useOASIS:
                # lambda from Selesnick's 3*sigma*|K| rule, using the noise
                # estimate std_rr from the exceptionality test
                if s_min_cand == 0:
                    if np.size(g) == 1:
                        denom = sqrt(1 - g**2)
                    else:
                        denom = sqrt((1 + g[1]) * ((1 - g[1])**2 - g[0]**2) / (1 - g[1]))
                    lam = 3 * std_rr / denom
                else:
                    lam = 0
                oas = oasis.OASIS(np.ravel(g)[0], lam, s_min_cand,
                                  num_empty_samples=t + 1 - len(cin_res),
                                  g2=0 if np.size(g) == 1 else g[1])
                for yt in cin_res:
                    oas.fit_next(yt)

            oases.append(oas)

        Ain_csc = scipy.sparse.csc_matrix((ain, (indeces, [0] * len(indeces))),
                                          (np.prod(dims), 1), dtype=np.float32)
        if Ab_dense is None:
            groups = update_order(Ab, Ain, groups)[0]
        else:
            groups = update_order(Ab_dense[indeces], ain, groups)[0]
            Ab_dense = np.hstack((Ab_dense, Ain))
        # faster version of scipy.sparse.hstack
        csc_append(Ab, Ain_csc)
        ind_A.append(Ab.indices[Ab.indptr[M]:Ab.indptr[M + 1]])

        # sufficient statistics: fill row M of CY (only the candidate's
        # pixels) and grow CC by one row/column, all scaled by 1/t
        tt = t * 1.
        CY[M, indeces] = cin.dot(Y_buf[:, indeces]) / tt
        CC1 = np.hstack([CC, Cf.dot(cin_circ / tt)[:, None]])
        CC2 = np.hstack(
            [(Cf.dot(cin_circ)).T, cin_circ.dot(cin_circ)]) / tt
        CC = np.vstack([CC1, CC2])
        Cf = np.vstack([Cf, cin_circ])

        M = M + 1

        # remove the new component from the residual buffer
        Yres_buf[:, indeces] -= np.outer(cin, ain)

        # Recompute the blurred residual energy inside the neighbourhood.
        # rho_buf and sv are indexed in C order. The blur is applied to the
        # neighbourhood itself, without a surrounding margin.
        region = tuple(slice(lo, hi) for lo, hi in ijSig)
        ind_vb = np.ravel_multi_index(
            np.ix_(*pixel_ranges), dims, order='C').ravel()
        vb_buf2 = np.stack([
            imblur(np.maximum(0, frame.reshape(dims, order='F')[region]),
                   sig=gSig, siz=gSiz, nDimBlur=len(dims)).ravel()
            for frame in Yres_buf])
        rho_buf[:, ind_vb] = vb_buf2**2
        sv[ind_vb] = np.sum(rho_buf[:, ind_vb], 0)

    return Ab, Cf, Yres_buf, rho_buf, CC, CY, ind_A, sv, groups, ind_new, ind_new_all, sv, cnn_pos
#%%
from scipy.stats import norm
#a = cm.load('example_movies/demoMovie.tif')
#a = cm.load('/mnt/ceph/neuro/labeling/yuste.Single_150u/images/tifs/Single_150um_024.tif')
#a = cm.load('/Users/agiovann/example_movies_ALL/quietBlock_2_ds_2_2.hdf5')
a = cm.load('/mnt/ceph/neuro/labeling/yuste.Single_150u/images/tifs/Yr_d1_200_d2_256_d3_1_order_C_frames_3000_.mmap')

# fraction of movie entries that survive the exceptionality threshold
all_els = []
for it in range(1):
    print(it)
#    a = cm.movie(np.random.randn(*mns.shape).astype(np.float32))
    Yr = remove_baseline_fast(np.array(cm.movie.to_2D(a)).T).T

    #%
    # Standardize each pixel with the md / sd_r values returned by
    # compute_event_exceptionality, then take -log(upper-tail p-value).
    # norm.logsf equals log(norm.sf) but stays finite where sf underflows to 0.
    # The builtin float replaces np.float, which was removed in NumPy 1.24.
    fitness, res, sd_r, md = compute_event_exceptionality(Yr.T)
    Yr_c = -norm.logsf(np.array((Yr - md) / sd_r, dtype=float))
    # sum the scores over 5 x 3 x 3 neighbourhoods and threshold
    mns = cm.movie(scipy.ndimage.convolve(np.reshape(
        Yr_c, [-1, a.shape[1], a.shape[2]], order='F'), np.ones([5, 3, 3])))
    mns[mns < (38 * np.log10((mns.shape[0])))] = 0
    all_els.append(np.sum(mns > 0) / Yr.size)
    print(all_els)


#%%
#m1 = cm.movie((np.array((Yr-md)/sd_r)).reshape([-1,60,80],order = 'F'))*(scipy.ndimage.convolve(mns>0,np.ones([5,3,3])))
# standardized movie restricted to the surviving entries
m1 = cm.movie((np.array((Yr - md) / sd_r)
               ).reshape([-1, a.shape[1], a.shape[2]], order='F')) * (mns > 0)
#%%
# standardized movie multiplied by the 5 x 3 x 3 convolution of the mask
m2 = cm.movie((np.array((Yr - md) / sd_r)).reshape(m1.shape, order='F')) * (scipy.ndimage.convolve(mns > 0, np.ones([5, 3, 3])))