# assumed imports for this example: oasisAR1 and dp_ell_0_active_sol come from
# the surrounding project; mean_squared_error is from scikit-learn
from sklearn.metrics import mean_squared_error


def cv_select_lambda(y, s_vec, t_vec, lambda_seq, L_min_seq, gamma=0.96):

    train_y = y[0::2]
    test_y = y[1::2]
    mse_c_l0 = []
    mse_c_l1 = []
    spike_l0 = []
    spike_l1 = []
    result_dict = dict()

    for lam in lambda_seq:
        print('lambda = ', lam)
        c_t_l1, s_t_l1 = oasisAR1(train_y, g=gamma, lam=lam, s_min=0)
        _, s_t_l0, c_t_l0 = dp_ell_0_active_sol(train_y,
                                                alpha=lam,
                                                gamma=gamma)

        test_c_t_l0 = [(a + b) / 2
                       for a, b in zip(c_t_l0[0:-1], c_t_l0[1:])]
        test_c_t_l1 = [(a + b) / 2
                       for a, b in zip(c_t_l1[0:-1], c_t_l1[1:])]

        mse_c_l0.append(mean_squared_error(test_c_t_l0, test_y[0:-1]))
        mse_c_l1.append(mean_squared_error(test_c_t_l1, test_y[0:-1]))

        # refit on the full trace to extract spike estimates
        c_t_l1, s_t_l1 = oasisAR1(y, g=gamma, lam=lam, s_min=0)
        _, s_t_l0, c_t_l0 = dp_ell_0_active_sol(y, alpha=lam, gamma=gamma)

        spike_l0.append(t_vec[s_t_l0])
        spike_l1.append([t_vec[s_t_l1 > L_min] for L_min in L_min_seq])

    result_dict['lambda_seq'] = lambda_seq
    result_dict['L_min_seq'] = L_min_seq
    result_dict['mse_c_l0'] = mse_c_l0
    result_dict['mse_c_l1'] = mse_c_l1
    result_dict['spike_l0'] = spike_l0
    result_dict['spike_l1'] = spike_l1

    return result_dict
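A minimal usage sketch for the function above, assuming oasisAR1 and dp_ell_0_active_sol are importable from the surrounding project; s_vec is accepted but unused by this variant, and the trace, time stamps, and grids below are made-up illustration values:

import numpy as np

y = np.random.randn(1000) * .3            # stand-in for a fluorescence trace
t_vec = np.arange(1000) / 30.             # hypothetical 30 Hz frame times
cv = cv_select_lambda(y, s_vec=None, t_vec=t_vec,
                      lambda_seq=np.logspace(-2, 1, 10),
                      L_min_seq=[0., .25, .5], gamma=0.96)
best_lam_l1 = cv['lambda_seq'][np.argmin(cv['mse_c_l1'])]
print('lambda minimizing the l1 test MSE:', best_lam_l1)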
Example #2
# real data from Chen et al. 2013, available at the following URL:
# https://portal.nersc.gov/project/crcns/download/cai-1/GCaMP6s_9cells_Chen2013/processed_data.tar.gz
filename = "/Users/joe/Downloads/data_20120627_cell2_002.mat"

################
#### Traces ####
################

# AR(1)

g = .95
sn = .3
Y, trueC, trueSpikes = gen_data()
N, T = Y.shape
result_oasis = oasisAR1(Y[0], g=g, lam=2.4)
result_foopsi = foopsi(Y[0], g=[g], lam=2.4)

fig = plt.figure(figsize=(20, 5.5))
fig.add_axes([.038, .57, .96, .42])
plt.plot(result_oasis[0], c=col[0], label='OASIS')
plt.plot(result_foopsi[0], '--', c=col[6], label='CVXPY')
plt.plot(trueC[0], c=col[2], label='Truth', zorder=-5)
plt.plot(Y[0], c=col[7], alpha=.7, zorder=-10, lw=1, label='Data')
plt.legend(frameon=False, ncol=4, loc=(.275, .82))
plt.gca().set_xticklabels([])
simpleaxis(plt.gca())
plt.ylim(Y[0].min(), Y[0].max())
plt.yticks([0, int(Y[0].max())], [0, int(Y[0].max())])
plt.xticks(range(750, 3000, 750), [''] * 3)
plt.ylabel('Fluor.')
Example #3
def constrained_onnlsAR2(y,
                         g,
                         sn,
                         optimize_b=True,
                         b_nonneg=True,
                         optimize_g=0,
                         decimate=5,
                         shift=100,
                         window=None,
                         tol=1e-9,
                         max_iter=1,
                         penalty=1):
    """ Infer the most likely discretized spike train underlying an AR(2) fluorescence trace

    Solves the noise-constrained sparse non-negative deconvolution problem
    min |s|_1  subject to  |c - y|^2 = sn^2 T  and  s_t = c_t - g1 c_{t-1} - g2 c_{t-2} >= 0

    Parameters
    ----------
    y : array of float
        One dimensional array containing the fluorescence intensities (with baseline
        already subtracted) with one entry per time-bin.
    g : (float, float)
        Parameters of the AR(2) process that models the fluorescence impulse response.
    sn : float
        Standard deviation of the noise distribution.
    optimize_b : bool, optional, default True
        Optimize the baseline if True; otherwise the baseline is set to 0 (see y).
    b_nonneg : bool, optional, default True
        Enforce strictly non-negative baseline if True.
    optimize_g : int, optional, default 0
        Number of large, isolated events to consider for optimizing g.
        No optimization if optimize_g=0.
    decimate : int, optional, default 5
        Decimation factor for estimating hyper-parameters faster on decimated data.
    shift : int, optional, default 100
        Number of frames by which to shift the window from one run of NNLS to the next.
    window : int, optional, default None (200 or larger, depending on g)
        Window size.
    tol : float, optional, default 1e-9
        Tolerance parameter.
    max_iter : int, optional, default 1
        Maximal number of iterations.
    penalty : int, optional, default 1
        Sparsity penalty. 1: min |s|_1, 0: min |s|_0.

    Returns
    -------
    c : array of float
        The inferred denoised fluorescence signal at each time-bin.
    s : array of float
        Discretized deconvolved neural activity (spikes).
    b : float
        Fluorescence baseline value.
    (g1, g2) : tuple of float
        Parameters of the AR(2) process that models the fluorescence impulse response.
    lam : float
        Sparsity penalty parameter lambda of dual problem.

    References
    ----------
    * Friedrich J and Paninski L, NIPS 2016
    * Friedrich J, Zhou P, and Paninski L, PLOS Computational Biology 2017
    """

    T = len(y)
    d = (g[0] + sqrt(g[0] * g[0] + 4 * g[1])) / 2
    r = (g[0] - sqrt(g[0] * g[0] + 4 * g[1])) / 2
    if window is None:
        window = int(min(T, max(200, -5 / log(d))))
    if not optimize_g:
        g11 = (np.exp(log(d) * np.arange(1, T + 1)) * np.arange(1, T + 1)) \
            if d == r else \
            (np.exp(log(d) * np.arange(1, T + 1)) -
             np.exp(log(r) * np.arange(1, T + 1))) / (d - r)
        g12 = np.append(0, g[1] * g11[:-1])
        g11g11 = np.cumsum(g11 * g11)
        g11g12 = np.cumsum(g11 * g12)
        Sg11 = np.cumsum(g11)
        f_lam = 1 - g[0] - g[1]
    elif decimate == 0:  # need to run AR1 anyway for estimating AR coefficients
        decimate = 1
    thresh = sn * sn * T
    # get initial estimate of b and lam on downsampled data using AR1 model
    if decimate > 0:
        _, s, b, aa, lam = constrained_oasisAR1(
            y[:len(y) // decimate * decimate].reshape(-1, decimate).mean(1),
            d**decimate,
            sn / sqrt(decimate),
            optimize_b=optimize_b,
            b_nonneg=b_nonneg,
            optimize_g=optimize_g)
        if optimize_g:
            d = aa**(1. / decimate)
            if decimate > 1:
                s = oasisAR1(y - b, d, lam=lam * (1 - aa) / (1 - d))[1]
            r = estimate_time_constant(s, 1, fudge_factor=.98)[0]
            g[0] = d + r
            g[1] = -d * r
            g11 = (np.exp(log(d) * np.arange(1, T + 1)) -
                   np.exp(log(r) * np.arange(1, T + 1))) / (d - r)
            g12 = np.append(0, g[1] * g11[:-1])
            g11g11 = np.cumsum(g11 * g11)
            g11g12 = np.cumsum(g11 * g12)
            Sg11 = np.cumsum(g11)
            f_lam = 1 - g[0] - g[1]
        elif decimate > 1:
            s = oasisAR1(y - b, d, lam=lam * (1 - aa) / (1 - d))[1]
        lam *= (1 - d**decimate) / f_lam
        # s = oasisAR1(s, r)[1]
        # this window size seems necessary and sufficient
        ff = np.ravel(
            [a + np.arange(-2, 2) for a in np.where(s > s.max() / 10.)[0]])
        ff = np.unique(ff[(ff >= 0) * (ff < T)]).astype(int)
        mask = np.zeros(T, dtype=bool)
        mask[ff] = True
    else:
        b = np.percentile(y, 15) if optimize_b else 0
        lam = 2 * sn * np.linalg.norm(g11)
        mask = None
    if b_nonneg:
        b = max(b, 0)
    # run ONNLS
    c, s = onnls(y - b,
                 g,
                 lam=lam,
                 mask=mask,
                 shift=shift,
                 window=window,
                 tol=tol)
    g_converged = False
    if not optimize_b:  # don't optimize b, just the dual variable lambda and g if optimize_g
        for i in range(max_iter - 1):
            res = y - c
            RSS = res.dot(res)
            if np.abs(RSS - thresh) < 1e-4:
                break
            # calc shift dlam, here attributed to sparsity penalty
            tmp = np.empty(T)
            ls = np.append(np.where(s > 1e-6)[0], T)
            l = ls[0]
            tmp[:l] = (1 + d) / (1 + d**l) * np.exp(
                log(d) * np.arange(l))  # first pool
            for i, f in enumerate(ls[:-1]):  # all other pools
                l = ls[i + 1] - f - 1
                # if and elif correct the last 2 time points for |s|_1 instead of |c|_1
                if i == len(ls) - 2:  # last pool
                    tmp[f] = (1. / f_lam if l == 0 else
                              (Sg11[l] + g[1] / f_lam * g11[l - 1] +
                               (g[0] + g[1]) / f_lam * g11[l] -
                               g11g12[l] * tmp[f - 1]) / g11g11[l])
                # second-to-last pool if the last one has length 1
                elif i == len(ls) - 3 and ls[-2] == T - 1:
                    tmp[f] = (Sg11[l] + g[1] / f_lam * g11[l] -
                              g11g12[l] * tmp[f - 1]) / g11g11[l]
                else:  # all other pools
                    tmp[f] = (Sg11[l] - g11g12[l] * tmp[f - 1]) / g11g11[l]
                l += 1
                tmp[f + 1:f + l] = g11[1:l] * tmp[f] + g12[1:l] * tmp[f - 1]

            aa = tmp.dot(tmp)
            bb = res.dot(tmp)
            cc = RSS - thresh
            try:
                dlam = (-bb + sqrt(bb * bb - aa * cc)) / aa
            except ValueError:  # negative discriminant; fall back to linear step
                dlam = -bb / aa
            # perform shift
            lam += dlam / f_lam
            c, s = onnls(y,
                         g,
                         lam=lam,
                         mask=mask,
                         shift=shift,
                         window=window,
                         tol=tol)

            # update g
            if optimize_g and (not g_converged):
                lengths = np.where(s)[0][1:] - np.where(s)[0][:-1]

                def getRSS(y, opt):
                    ld, lr = opt
                    if ld < lr:
                        return 1e3 * thresh
                    d, r = exp(ld), exp(lr)
                    g1, g2 = d + r, -d * r
                    tmp = onnls(y, [g1, g2], lam,
                                mask=(s > 1e-2 * s.max()))[0] - y
                    return tmp.dot(tmp)

                result = minimize(lambda x: getRSS(y, x), (log(d), log(r)),
                                  bounds=((None, -1e-4), (None, -1e-3)),
                                  method='L-BFGS-B',
                                  options={
                                      'gtol': 1e-04,
                                      'maxiter': 10,
                                      'ftol': 1e-05
                                  })
                if abs(result['x'][0] - log(d)) < 1e-4:  # x[0] is log(d) here
                    g_converged = True
                ld, lr = result['x']
                d, r = exp(ld), exp(lr)
                g = (d + r, -d * r)
                c, s = onnls(y,
                             g,
                             lam=lam,
                             mask=mask,
                             shift=shift,
                             window=window,
                             tol=tol)

    else:  # optimize b
        db = max(np.mean(y - c), 0 if b_nonneg else -np.inf) - b
        b += db
        lam -= db / (1 - g[0] - g[1])
        for i in range(max_iter - 1):
            res = y - c - b
            RSS = res.dot(res)
            if np.abs(RSS - thresh) < 1e-4:
                break
            # calc shift db, here attributed to baseline
            tmp = np.empty(T)
            ls = np.append(np.where(s > 1e-6)[0], T)
            l = ls[0]
            tmp[:l] = (1 + d) / (1 + d**l) * np.exp(
                log(d) * np.arange(l))  # first pool
            for i, f in enumerate(ls[:-1]):  # all other pools
                l = ls[i + 1] - f
                tmp[f] = (Sg11[l - 1] -
                          g11g12[l - 1] * tmp[f - 1]) / g11g11[l - 1]
                tmp[f + 1:f + l] = g11[1:l] * tmp[f] + g12[1:l] * tmp[f - 1]
            tmp -= tmp.mean()
            aa = tmp.dot(tmp)
            bb = res.dot(tmp)
            cc = RSS - thresh
            try:
                db = (-bb + sqrt(bb * bb - aa * cc)) / aa
            except ValueError:  # negative discriminant; fall back to linear step
                db = -bb / aa
            # perform shift
            if b_nonneg:
                db = max(db, -b)
            b += db
            c, s = onnls(y - b,
                         g,
                         lam=lam,
                         mask=mask,
                         shift=shift,
                         window=window,
                         tol=tol)
            # update b and lam
            db = max(np.mean(y - c), 0 if b_nonneg else -np.inf) - b
            b += db
            lam -= db / f_lam

            # update g and b
            if optimize_g and (not g_converged):
                lengths = np.where(s)[0][1:] - np.where(s)[0][:-1]

                def getRSS(y, opt):
                    b, ld, lr = opt
                    if ld < lr:
                        return 1e3 * thresh
                    d, r = exp(ld), exp(lr)
                    g1, g2 = d + r, -d * r
                    tmp = b + onnls(
                        y - b, [g1, g2], lam, mask=(s > 1e-2 * s.max()))[0] - y
                    return tmp.dot(tmp)

                result = minimize(lambda x: getRSS(y, x), (b, log(d), log(r)),
                                  bounds=((0 if b_nonneg else None, None),
                                          (None, -1e-4), (None, -1e-3)),
                                  method='L-BFGS-B',
                                  options={
                                      'gtol': 1e-04,
                                      'maxiter': 10,
                                      'ftol': 1e-05
                                  })
                if abs(result['x'][1] - log(d)) < 1e-3:
                    g_converged = True
                b, ld, lr = result['x']
                d, r = exp(ld), exp(lr)
                g = (d + r, -d * r)
                c, s = onnls(y - b,
                             g,
                             lam=lam,
                             mask=mask,
                             shift=shift,
                             window=window,
                             tol=tol)
                # update b and lam
                db = max(np.mean(y - c), 0 if b_nonneg else -np.inf) - b
                b += db
                lam -= db

    if penalty == 0:  # get (locally optimal) L0 solution

        def c4smin(y, s, s_min):
            ls = np.append(np.where(s > s_min)[0], T)
            tmp = np.zeros_like(s)
            l = ls[0]  # first pool
            tmp[:l] = max(
                0,
                np.exp(log(d) * np.arange(l)).dot(y[:l]) * (1 - d * d) /
                (1 - d**(2 * l))) * np.exp(log(d) * np.arange(l))
            for i, f in enumerate(ls[:-1]):  # all other pools
                l = ls[i + 1] - f
                tmp[f] = (g11[:l].dot(y[f:f + l]) -
                          g11g12[l - 1] * tmp[f - 1]) / g11g11[l - 1]
                tmp[f + 1:f + l] = g11[1:l] * tmp[f] + g12[1:l] * tmp[f - 1]
            return tmp

        spikesizes = np.sort(s[s > 1e-6])
        i = len(spikesizes) // 2
        l = 0
        u = len(spikesizes) - 1
        while u - l > 1:
            s_min = spikesizes[i]
            tmp = c4smin(y - b, s, s_min)
            res = y - b - tmp
            RSS = res.dot(res)
            if RSS < thresh or i == 0:
                l = i
                i = (l + u) // 2
                res0 = tmp
            else:
                u = i
                i = (l + u) // 2
        if i > 0:
            c = res0
            s = np.append([0, 0], c[2:] - g[0] * c[1:-1] - g[1] * c[:-2])

    return c, s, b, g, lam
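A brief usage sketch for constrained_onnlsAR2, reusing the AR(2) settings that appear in Example #4 below; gen_data is the simulator used throughout these examples:

g = [1.7, -.712]
sn = 1.
Y, trueC, trueSpikes = gen_data(g, sn, seed=3)
c, s, b, g_fit, lam = constrained_onnlsAR2(Y[0], g, sn, optimize_b=True)
print('baseline b = %.3f, dual lambda = %.3f' % (b, lam))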
Example #4
    plt.xlim(0, 452)
    plt.ylabel(r'$s_{\min}$')
    plt.xlabel('Time [s]', labelpad=-10)
    plt.show()


# AR(1)

g = .95
sn = .3
Y, trueC, trueSpikes = gen_data()
y = Y[0]
N, T = Y.shape

c, s = constrained_foopsi(y, [g], sn)[:2]
c_t, s_t = oasisAR1(y, g, s_min=.55)
res = [np.where(oasisAR1(y, g, s_min=s0)[1] > 1e-2)[0]
       for s0 in np.arange(0, 1.1, .1)]
plotTrace(True)


# AR(2)

g = [1.7, -.712]
sn = 1.
Y, trueC, trueSpikes = gen_data(g, sn, seed=3)
rng = slice(150, 600)
trueC = trueC[:, rng]
trueSpikes = trueSpikes[:, rng]
y = Y[0, rng]
N, T = Y.shape
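A quick sanity check of the calcium kernel implied by these AR(2) coefficients, mirroring the d/r root decomposition used inside constrained_onnlsAR2 (its g11 vector); purely illustrative:

import numpy as np
from math import sqrt, log

g1, g2 = 1.7, -.712
d = (g1 + sqrt(g1 * g1 + 4 * g2)) / 2     # decay root
r = (g1 - sqrt(g1 * g1 + 4 * g2)) / 2     # rise root
t = np.arange(1, 201)
kernel = (np.exp(log(d) * t) - np.exp(log(r) * t)) / (d - r)
print('impulse response peaks at frame', kernel.argmax() + 1)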
Example #5
ax = fig.add_axes([ax1, .31, 1 - ax1, .12])
plot_trace(n)

# do a few more iterations
for _ in range(3):
    solution, active_set, lam = update_lam(y, solution, active_set, g, lam,
                                           sn * sn * len(y))
    solution, active_set, g = update_g(y, active_set, g, lam)

# plot converged results with comparison traces
ax = fig.add_axes([ax1, .07, 1 - ax1, .12])
sol_given_g = constrained_oasisAR1(y, .95, sn)[0]
estimated_g = estimate_parameters(y, p=1)[0][0]
print('estimated gamma via autocorrelation: ', estimated_g)
print('optimized gamma                    : ', g)
sol_PSD_g = oasisAR1(y, estimated_g, 0)[0]
# print((sol_PSD_g - y).dot(sol_PSD_g - y), sn * sn * T)  # renders constraint problem infeasible
plt.plot(sol_given_g, '--', c=col[6], label=r'true $\gamma$', zorder=11)
plt.plot(sol_PSD_g, c=col[5], label=r'$\gamma$ from autocovariance', zorder=10)
plt.legend(frameon=False, loc=(.1, .62), ncol=2)
plot_trace(n)
plt.xticks([300, 600, 900, 1200], [10, 20, 30, 40])
plt.xlabel('Time [s]', labelpad=-10)
plt.show()

print('correlation with ground truth calcium for   given   gamma ',
      np.corrcoef(sol_given_g, trueC[n])[0, 1])
print('correlation with ground truth calcium for estimated gamma ',
      np.corrcoef(sol_PSD_g, trueC[n])[0, 1])
print('correlation with ground truth calcium for optimized gamma ',
      np.corrcoef(solution, trueC[n])[0, 1])
Example #6
N, T = Y.shape

results = {}
for opt in [
        '-', 'l', 'lb', 'lbg', 'lbg10', 'lbg5', 'lbg_ds', 'lbg10_ds', 'lbg5_ds'
]:
    results[opt] = {}
    results[opt]['time'] = []
    results[opt]['distance'] = []
    results[opt]['correlation'] = []
    for i, y in enumerate(Y):
        g, sn = estimate_parameters(y, p=1, fudge_factor=.99)
        lam = 2 * sn * (1 - g * g)**(-.5)
        b = np.percentile(y, 15)
        if opt == '-':
            foo = lambda y: oasisAR1(y - b, g, lam)
        elif opt == 'l':
            foo = lambda y: constrained_oasisAR1(y - b, g, sn)
        elif opt == 'lb':
            foo = lambda y: constrained_oasisAR1(y, g, sn, optimize_b=True)
        elif opt == 'lbg':
            foo = lambda y: constrained_oasisAR1(
                y, g, sn, optimize_b=True, optimize_g=len(y))
        elif opt == 'lbg10':
            foo = lambda y: constrained_oasisAR1(
                y, g, sn, optimize_b=True, optimize_g=10)
        elif opt == 'lbg5':
            foo = lambda y: constrained_oasisAR1(
                y, g, sn, optimize_b=True, optimize_g=5)
        elif opt == 'lbg_ds':
            foo = lambda y: constrained_oasisAR1(
Example #7
import argparse
import os
import pandas as pd

parser = argparse.ArgumentParser(description='input to oasis wrapper')
parser.add_argument('y_file')
parser.add_argument('gam', type=float)
parser.add_argument('theta', type=float)
parser.add_argument('type', type=str)
parser.add_argument('out_file')

args = parser.parse_args()

# read in tmp data from y_file

print(os.getcwd())

print("trying to open file %s" % args.y_file)
abs_file_path = args.y_file

df = pd.read_csv(abs_file_path, sep=',', header=None)
y = df.values.flatten()

# wrap oasisAR1 for easy calls from R
if args.type == 'penalized':
    fit = oasisAR1(y, args.gam, lam=args.theta)
else:
    fit = oasisAR1(y, args.gam, s_min=args.theta)

df = pd.DataFrame([fit[0], fit[1]]).transpose()
df.to_csv(args.out_file, index=False, header=['calcium', 'spikes'])
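The wrapper above is meant to be invoked from the command line (e.g. by R via system()); a hypothetical invocation, with made-up file names:

# python oasis_wrapper.py y.csv 0.96 1.5 penalized fit.csv
#   y_file   : one-column CSV of fluorescence values
#   gam      : AR(1) coefficient gamma
#   theta    : lambda for type 'penalized', otherwise the s_min threshold
#   type     : 'penalized' selects the lam branch; any other value uses s_min
#   out_file : output CSV with columns 'calcium' and 'spikes'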
Example #8
# assumed imports for this example: oasisAR1 and dp_ell_0_active_sol come from
# the surrounding project; the rest are third-party (numpy, scikit-learn,
# neo, elephant, quantities)
import numpy as np
import elephant
from neo import SpikeTrain
from quantities import s
from sklearn.metrics import mean_squared_error


def cv_select_lambda(y, s_vec, t_vec, lambda_seq, L_min_seq, gamma=0.96):

    train_y = y[0::2]
    test_y = y[1::2]
    mse_c_l0 = []
    mse_c_l1 = []
    result_dict = dict()

    l_0_spike_collection = []
    l_1_spike_collection = []
    l_0_vr_list = []
    l_1_vr_list = []
    l_0_vp_list = []
    l_1_vp_list = []

    # ground-truth spike train (constant across the lambda grid)
    spike_truth = SpikeTrain(s_vec * s, t_stop=t_vec[-1])

    for lam in lambda_seq:
        print('lambda = ', lam)
        # fit on the training half (even-indexed samples)
        c_t_l1, s_t_l1 = oasisAR1(train_y, g=gamma, lam=lam, s_min=0)
        _, s_t_l0, c_t_l0 = dp_ell_0_active_sol(train_y, alpha=lam, gamma=gamma)

        # interpolate the fitted calcium onto the held-out (odd-indexed) samples
        test_c_t_l0 = [(a + b) / 2
                       for a, b in zip(c_t_l0[0:-1], c_t_l0[1:])]
        test_c_t_l1 = [(a + b) / 2
                       for a, b in zip(c_t_l1[0:-1], c_t_l1[1:])]

        mse_c_l0.append(mean_squared_error(test_c_t_l0, test_y[0:-1]))
        mse_c_l1.append(mean_squared_error(test_c_t_l1, test_y[0:-1]))

        # refit on the full trace to extract spike estimates
        c_t_l1, s_t_l1 = oasisAR1(y, g=gamma, lam=lam, s_min=0)
        _, s_t_l0, c_t_l0 = dp_ell_0_active_sol(y, alpha=lam, gamma=gamma)

        l_0_spike_collection.append(t_vec[s_t_l0])
        l_1_spike_collection.append([t_vec[np.where(s_t_l1 > L_min)]
                                     for L_min in L_min_seq])
        l_1_spike_list = [SpikeTrain(t_vec[np.where(s_t_l1 > L_min)] * s,
                                     t_stop=t_vec[-1])
                          for L_min in L_min_seq]
        l_0_spike = SpikeTrain(t_vec[s_t_l0] * s, t_stop=t_vec[-1])

        # compute post-processing spike-distance estimators
        l_1_vr = [elephant.spike_train_dissimilarity.van_rossum_dist(
            [l_1_spike, spike_truth], tau=2. * s)[0, 1]
            for l_1_spike in l_1_spike_list]
        l_1_vp = [elephant.spike_train_dissimilarity.victor_purpura_dist(
            [l_1_spike, spike_truth])[0, 1]
            for l_1_spike in l_1_spike_list]
        l_0_vr = elephant.spike_train_dissimilarity.van_rossum_dist(
            [l_0_spike, spike_truth], tau=2. * s)[0, 1]
        l_0_vp = elephant.spike_train_dissimilarity.victor_purpura_dist(
            [l_0_spike, spike_truth])[0, 1]

        l_0_vr_list.append(l_0_vr)
        l_1_vr_list.append(l_1_vr)
        l_0_vp_list.append(l_0_vp)
        l_1_vp_list.append(l_1_vp)

    # store results
    result_dict['lambda_seq'] = lambda_seq
    result_dict['L_min_seq'] = L_min_seq
    result_dict['mse_c_l0'] = mse_c_l0
    result_dict['mse_c_l1'] = mse_c_l1
    result_dict['l_0_vr_list'] = l_0_vr_list
    result_dict['l_1_vr_list'] = l_1_vr_list
    result_dict['l_0_vp_list'] = l_0_vp_list
    result_dict['l_1_vp_list'] = l_1_vp_list
    result_dict['l_1_spike_collection'] = l_1_spike_collection
    result_dict['l_0_spike_collection'] = l_0_spike_collection
    return result_dict
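A hedged selection sketch using the dictionary returned above; it assumes a prior call result = cv_select_lambda(...) with the input conventions of this example:

import numpy as np

# given: result = cv_select_lambda(y, s_vec, t_vec, lambda_seq, L_min_seq)
best_l0_lam = result['lambda_seq'][np.argmin(result['l_0_vp_list'])]
vp = np.array(result['l_1_vp_list'])      # shape (len(lambda_seq), len(L_min_seq))
i, j = np.unravel_index(vp.argmin(), vp.shape)
best_l1_lam, best_L_min = result['lambda_seq'][i], result['L_min_seq'][j]
print('l0 pick:', best_l0_lam, '  l1 pick:', best_l1_lam, best_L_min)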