def cv_select_lambda(y, s_vec, t_vec, lambda_seq, L_min_seq, gamma=0.96):
    """Score every penalty in ``lambda_seq`` by even/odd hold-out error.

    The trace is split into even-indexed (training) and odd-indexed
    (held-out) samples.  For each ``lam`` both solvers are fitted on the
    training half, neighbouring fitted values are averaged to predict the
    held-out samples lying between them, and the mean squared error of that
    prediction is recorded.  Both solvers are then re-run on the full trace
    to collect spike times.

    Parameters
    ----------
    y : sequence of float
        Fluorescence trace; must support slicing.
    s_vec : unused
        Not referenced by this variant; kept so the signature stays
        compatible with existing callers.
    t_vec : array
        Time stamps.  It is indexed with the L0 solver's spike output and
        with a boolean mask derived from the L1 spike amplitudes, so it is
        presumably a numpy array aligned with ``y`` -- TODO confirm.
    lambda_seq : iterable of float
        Penalty values to evaluate (passed as ``lam`` to ``oasisAR1`` and as
        ``alpha`` to ``dp_ell_0_active_sol``).
    L_min_seq : iterable of float
        Amplitude thresholds applied to the L1 spike estimate.
    gamma : float, optional, default 0.96
        AR(1) decay parameter handed to both solvers.

    Returns
    -------
    dict
        Keys ``'lambda_seq'``, ``'L_min_seq'``, ``'mse_c_l0'``,
        ``'mse_c_l1'``, ``'spike_l0'`` and ``'spike_l1'``; the last four hold
        one entry per element of ``lambda_seq``.
    """
    train_y = y[0::2]
    test_y = y[1::2]

    def _holdout_mse(c_train):
        # Midpoints of consecutive training estimates predict the odd samples
        # that sit between them.
        midpoints = [(a + b) / 2 for a, b in zip(c_train[0:-1], c_train[1:])]
        # Compare against exactly as many held-out samples as there are
        # midpoints.  For an even-length trace this equals test_y[0:-1]; for
        # an odd-length trace every held-out sample has a midpoint, and the
        # former fixed [0:-1] slice was one element short, which made the
        # two inputs differ in length.
        return mean_squared_error(midpoints, test_y[:len(midpoints)])

    mse_c_l0 = []
    mse_c_l1 = []
    spike_l0 = []
    spike_l1 = []
    for lam in lambda_seq:
        print('lambda = ', lam)

        # cross-validation fit on the even samples only
        c_t_l1, s_t_l1 = oasisAR1(train_y, g=gamma, lam=lam, s_min=0)
        _, s_t_l0, c_t_l0 = dp_ell_0_active_sol(train_y, alpha=lam, gamma=gamma)
        mse_c_l0.append(_holdout_mse(c_t_l0))
        mse_c_l1.append(_holdout_mse(c_t_l1))

        # refit on the complete trace to extract spike times
        c_t_l1, s_t_l1 = oasisAR1(y, g=gamma, lam=lam, s_min=0)
        _, s_t_l0, c_t_l0 = dp_ell_0_active_sol(y, alpha=lam, gamma=gamma)
        spike_l0.append(t_vec[s_t_l0])
        spike_l1.append([t_vec[s_t_l1 > L_min] for L_min in L_min_seq])

    return {
        'lambda_seq': lambda_seq,
        'L_min_seq': L_min_seq,
        'mse_c_l0': mse_c_l0,
        'mse_c_l1': mse_c_l1,
        'spike_l0': spike_l0,
        'spike_l1': spike_l1,
    }
]  # closes a bracketed literal that was opened above the visible fragment
# real data from Chen et al 2013, available at following URL
# https://portal.nersc.gov/project/crcns/download/cai-1/GCaMP6s_9cells_Chen2013/processed_data.tar.gz
# NOTE(review): absolute, user-specific path; it is not read anywhere in the
# visible part of this script.
filename = "/Users/joe/Downloads/data_20120627_cell2_002.mat"

################
#### Traces ####
################

# AR(1)
g = .95
sn = .3
# gen_data() is called with its defaults, so g and sn above are not passed to it here.
Y, trueC, trueSpikes = gen_data()
N, T = Y.shape
# Deconvolve the first trace with both solvers using the same penalty (lam=2.4).
result_oasis = oasisAR1(Y[0], g=g, lam=2.4)
result_foopsi = foopsi(Y[0], g=[g], lam=2.4)

fig = plt.figure(figsize=(20, 5.5))
# Upper panel of the figure; the rectangle is given in figure fractions.
fig.add_axes([.038, .57, .96, .42])
# Element [0] of each solver result is plotted as the denoised trace.
plt.plot(result_oasis[0], c=col[0], label='OASIS')
plt.plot(result_foopsi[0], '--', c=col[6], label='CVXPY')
# Negative zorder keeps ground truth and raw data underneath the two estimates.
plt.plot(trueC[0], c=col[2], label='Truth', zorder=-5)
plt.plot(Y[0], c=col[7], alpha=.7, zorder=-10, lw=1, label='Data')
plt.legend(frameon=False, ncol=4, loc=(.275, .82))
plt.gca().set_xticklabels([])
simpleaxis(plt.gca())
# Clip the y-range to the data and label only 0 and the integer part of the maximum.
plt.ylim(Y[0].min(), Y[0].max())
plt.yticks([0, int(Y[0].max())], [0, int(Y[0].max())])
# Tick marks at 750, 1500 and 2250 with empty labels.
plt.xticks(range(750, 3000, 750), [''] * 3)
plt.ylabel('Fluor.')
def constrained_onnlsAR2(y, g, sn, optimize_b=True, b_nonneg=True, optimize_g=0, decimate=5,
                         shift=100, window=None, tol=1e-9, max_iter=1, penalty=1):
    """ Infer the most likely discretized spike train underlying an AR(2) fluorescence trace

    Solves the noise constrained sparse non-negative deconvolution problem
    min |s|_1 subject to |c-y|^2 = sn^2 T and s_t = c_t-g1 c_{t-1}-g2 c_{t-2} >= 0

    Parameters
    ----------
    y : array of float
        One dimensional array containing the fluorescence intensities (with baseline
        already subtracted) with one entry per time-bin.
    g : (float, float)
        Parameters of the AR(2) process that models the fluorescence impulse response.
        The argument is never modified in place; a re-estimated value is only
        available through the return value.
    sn : float
        Standard deviation of the noise distribution.
    optimize_b : bool, optional, default True
        Optimize baseline if True else it is set to 0, see y.
    b_nonneg: bool, optional, default True
        Enforce strictly non-negative baseline if True.
    optimize_g : int, optional, default 0
        Number of large, isolated events to consider for optimizing g.
        No optimization if optimize_g=0.
    decimate : int, optional, default 5
        Decimation factor for estimating hyper-parameters faster on decimated data.
    shift : int, optional, default 100
        Number of frames by which to shift window from one run of NNLS to the next.
    window : int, optional, default None (200 or larger dependent on g)
        Window size.
    tol : float, optional, default 1e-9
        Tolerance parameter.
    max_iter : int, optional, default 1
        Maximal number of iterations.  The refinement loops run
        ``max_iter - 1`` times, so the default performs none.
    penalty : int, optional, default 1
        Sparsity penalty. 1: min |s|_1  0: min |s|_0

    Returns
    -------
    c : array of float
        The inferred denoised fluorescence signal at each time-bin.
    s : array of float
        Discretized deconvolved neural activity (spikes).
    b : float
        Fluorescence baseline value.
    (g1, g2) : tuple of float
        Parameters of the AR(2) process that models the fluorescence impulse response.
        This is ``g`` as passed in unless ``optimize_g`` re-estimated it.
    lam : float
        Sparsity penalty parameter lambda of dual problem.

    References
    ----------
    * Friedrich J and Paninski L, NIPS 2016
    * Friedrich J, Zhou P, and Paninski L, PLOS Computational Biology 2017
    """
    T = len(y)
    # d and r are the two roots of x^2 - g[0] x - g[1] = 0 (d is the larger one).
    d = (g[0] + sqrt(g[0] * g[0] + 4 * g[1])) / 2
    r = (g[0] - sqrt(g[0] * g[0] + 4 * g[1])) / 2
    if window is None:
        window = int(min(T, max(200, -5 / log(d))))
    if not optimize_g:
        # Kernel (d^t - r^t) / (d - r) for t = 1..T, with the t * d^t form
        # used when both roots coincide, plus the cumulative sums that the
        # pool formulas below index into.
        g11 = (np.exp(log(d) * np.arange(1, T + 1)) * np.arange(1, T + 1)) \
            if d == r else \
            (np.exp(log(d) * np.arange(1, T + 1)) -
             np.exp(log(r) * np.arange(1, T + 1))) / (d - r)
        g12 = np.append(0, g[1] * g11[:-1])
        g11g11 = np.cumsum(g11 * g11)
        g11g12 = np.cumsum(g11 * g12)
        Sg11 = np.cumsum(g11)
        f_lam = 1 - g[0] - g[1]
    elif decimate == 0:  # need to run AR1 anyways for estimating AR coeffs
        decimate = 1
    thresh = sn * sn * T  # target residual sum of squares
    # get initial estimate of b and lam on downsampled data using AR1 model
    if decimate > 0:
        # Average consecutive groups of `decimate` samples; a trailing
        # remainder that does not fill a whole group is dropped.
        _, s, b, aa, lam = constrained_oasisAR1(
            y[:len(y) // decimate * decimate].reshape(-1, decimate).mean(1),
            d**decimate, sn / sqrt(decimate),
            optimize_b=optimize_b, b_nonneg=b_nonneg, optimize_g=optimize_g)
        if optimize_g:
            # Map the decay estimated on decimated data back to the full rate.
            d = aa**(1. / decimate)
            if decimate > 1:
                s = oasisAR1(y - b, d, lam=lam * (1 - aa) / (1 - d))[1]
            r = estimate_time_constant(s, 1, fudge_factor=.98)[0]
            # Rebind instead of assigning g[0]/g[1]: item assignment altered
            # the caller's object and failed outright for a tuple.
            g = [d + r, -d * r]
            g11 = (np.exp(log(d) * np.arange(1, T + 1)) -
                   np.exp(log(r) * np.arange(1, T + 1))) / (d - r)
            g12 = np.append(0, g[1] * g11[:-1])
            g11g11 = np.cumsum(g11 * g11)
            g11g12 = np.cumsum(g11 * g12)
            Sg11 = np.cumsum(g11)
            f_lam = 1 - g[0] - g[1]
        elif decimate > 1:
            s = oasisAR1(y - b, d, lam=lam * (1 - aa) / (1 - d))[1]
        lam *= (1 - d**decimate) / f_lam

        # s = oasisAR1(s, r)[1]
        # this window size seems necessary and sufficient
        # Allow spikes only within [-2, +1] frames of the larger AR(1) events
        # (those above a tenth of the maximum).
        ff = np.ravel(
            [a + np.arange(-2, 2) for a in np.where(s > s.max() / 10.)[0]])
        ff = np.unique(ff[(ff >= 0) * (ff < T)]).astype(int)
        mask = np.zeros(T, dtype=bool)
        mask[ff] = True
    else:
        b = np.percentile(y, 15) if optimize_b else 0
        lam = 2 * sn * np.linalg.norm(g11)
        mask = None
    if b_nonneg:
        b = max(b, 0)

    # run ONNLS
    c, s = onnls(y - b, g, lam=lam, mask=mask, shift=shift, window=window, tol=tol)
    g_converged = False
    if not optimize_b:  # don't optimize b, just the dual variable lambda and g if optimize_g
        for i in range(max_iter - 1):
            res = y - c
            RSS = res.dot(res)
            if np.abs(RSS - thresh) < 1e-4:
                break
            # calc shift dlam, here attributed to sparsity penalty
            tmp = np.empty(T)
            ls = np.append(np.where(s > 1e-6)[0], T)
            l = ls[0]
            tmp[:l] = (1 + d) / (1 + d**l) * np.exp(
                log(d) * np.arange(l))  # first pool
            # NOTE(review): when the first spike sits at index 0, tmp[f - 1]
            # below reads tmp[-1], which np.empty has not initialised yet.
            # The inner loop reuses the name `i`; the outer range iterator is
            # unaffected by that rebinding.
            for i, f in enumerate(ls[:-1]):  # all other pools
                l = ls[i + 1] - f - 1
                # if and elif correct last 2 time points for |s|_1 instead |c|_1
                if i == len(ls) - 2:  # last pool
                    tmp[f] = (1. / f_lam if l == 0 else
                              (Sg11[l] + g[1] / f_lam * g11[l - 1]
                               + (g[0] + g[1]) / f_lam * g11[l]
                               - g11g12[l] * tmp[f - 1]) / g11g11[l])
                # secondlast pool if last one has length 1
                elif i == len(ls) - 3 and ls[-2] == T - 1:
                    tmp[f] = (Sg11[l] + g[1] / f_lam * g11[l]
                              - g11g12[l] * tmp[f - 1]) / g11g11[l]
                else:  # all other pools
                    tmp[f] = (Sg11[l] - g11g12[l] * tmp[f - 1]) / g11g11[l]
                l += 1
                tmp[f + 1:f + l] = g11[1:l] * tmp[f] + g12[1:l] * tmp[f - 1]

            # Solve aa*x^2 + 2*bb*x + cc = 0 for the step that brings the RSS
            # to thresh.
            aa = tmp.dot(tmp)
            bb = res.dot(tmp)
            cc = RSS - thresh
            try:
                dlam = (-bb + sqrt(bb * bb - aa * cc)) / aa
            except ValueError:
                # Negative discriminant (assumes `sqrt` is math.sqrt, which
                # raises ValueError -- TODO confirm against the file's
                # imports): fall back to the minimiser of the quadratic.
                dlam = -bb / aa
            # perform shift
            lam += dlam / f_lam
            c, s = onnls(y, g, lam=lam, mask=mask, shift=shift, window=window, tol=tol)

            # update g
            if optimize_g and (not g_converged):

                def getRSS(y, opt):
                    # Residual sum of squares as a function of the log-roots.
                    ld, lr = opt
                    if ld < lr:
                        return 1e3 * thresh
                    d, r = exp(ld), exp(lr)
                    g1, g2 = d + r, -d * r
                    tmp = onnls(y, [g1, g2], lam,
                                mask=(s > 1e-2 * s.max()))[0] - y
                    return tmp.dot(tmp)

                result = minimize(lambda x: getRSS(y, x), (log(d), log(r)),
                                  bounds=((None, -1e-4), (None, -1e-3)), method='L-BFGS-B',
                                  options={'gtol': 1e-04, 'maxiter': 10, 'ftol': 1e-05})
                if abs(result['x'][1] - log(d)) < 1e-4:
                    g_converged = True
                ld, lr = result['x']
                d, r = exp(ld), exp(lr)
                g = (d + r, -d * r)
                c, s = onnls(y, g, lam=lam, mask=mask,
                             shift=shift, window=window, tol=tol)

    else:  # optimize b
        db = max(np.mean(y - c), 0 if b_nonneg else -np.inf) - b
        b += db
        lam -= db / (1 - g[0] - g[1])
        for i in range(max_iter - 1):
            res = y - c - b
            RSS = res.dot(res)
            if np.abs(RSS - thresh) < 1e-4:
                break
            # calc shift db, here attributed to baseline
            tmp = np.empty(T)
            ls = np.append(np.where(s > 1e-6)[0], T)
            l = ls[0]
            tmp[:l] = (1 + d) / (1 + d**l) * np.exp(log(d) * np.arange(l))  # first pool
            for i, f in enumerate(ls[:-1]):  # all other pools
                l = ls[i + 1] - f
                tmp[f] = (Sg11[l - 1] - g11g12[l - 1] * tmp[f - 1]) / g11g11[l - 1]
                tmp[f + 1:f + l] = g11[1:l] * tmp[f] + g12[1:l] * tmp[f - 1]
            tmp -= tmp.mean()
            aa = tmp.dot(tmp)
            bb = res.dot(tmp)
            cc = RSS - thresh
            try:
                db = (-bb + sqrt(bb * bb - aa * cc)) / aa
            except ValueError:
                # Same fallback as in the lambda-only branch: no real root, so
                # take the minimiser of the quadratic.
                db = -bb / aa
            # perform shift
            if b_nonneg:
                db = max(db, -b)
            b += db
            c, s = onnls(y - b, g, lam=lam, mask=mask, shift=shift, window=window, tol=tol)

            # update b and lam
            db = max(np.mean(y - c), 0 if b_nonneg else -np.inf) - b
            b += db
            lam -= db / f_lam

            # update g and b
            if optimize_g and (not g_converged):

                def getRSS(y, opt):
                    # Residual sum of squares as a function of the baseline
                    # and the log-roots.
                    b, ld, lr = opt
                    if ld < lr:
                        return 1e3 * thresh
                    d, r = exp(ld), exp(lr)
                    g1, g2 = d + r, -d * r
                    tmp = b + onnls(y - b, [g1, g2], lam,
                                    mask=(s > 1e-2 * s.max()))[0] - y
                    return tmp.dot(tmp)

                result = minimize(lambda x: getRSS(y, x), (b, log(d), log(r)),
                                  bounds=((0 if b_nonneg else None, None),
                                          (None, -1e-4), (None, -1e-3)),
                                  method='L-BFGS-B',
                                  options={'gtol': 1e-04, 'maxiter': 10, 'ftol': 1e-05})
                if abs(result['x'][1] - log(d)) < 1e-3:
                    g_converged = True
                b, ld, lr = result['x']
                d, r = exp(ld), exp(lr)
                g = (d + r, -d * r)
                c, s = onnls(y - b, g, lam=lam, mask=mask,
                             shift=shift, window=window, tol=tol)

                # update b and lam
                db = max(np.mean(y - c), 0 if b_nonneg else -np.inf) - b
                b += db
                # NOTE(review): the other baseline updates divide db by f_lam;
                # this one does not.  Left unchanged -- confirm intent.
                lam -= db

    if penalty == 0:  # get (locally optimal) L0 solution

        def c4smin(y, s, s_min):
            # Refit the trace keeping only the spikes of s that exceed s_min.
            ls = np.append(np.where(s > s_min)[0], T)
            tmp = np.zeros_like(s)
            l = ls[0]  # first pool
            tmp[:l] = max(0, np.exp(log(d) * np.arange(l)).dot(y[:l]) * (1 - d * d)
                          / (1 - d**(2 * l))) * np.exp(log(d) * np.arange(l))
            for i, f in enumerate(ls[:-1]):  # all other pools
                l = ls[i + 1] - f
                tmp[f] = (g11[:l].dot(y[f:f + l]) - g11g12[l - 1] * tmp[f - 1]) / g11g11[l - 1]
                tmp[f + 1:f + l] = g11[1:l] * tmp[f] + g12[1:l] * tmp[f - 1]
            return tmp

        # Bisect over the sorted spike sizes for the largest threshold whose
        # refit still satisfies RSS < thresh.
        spikesizes = np.sort(s[s > 1e-6])
        i = len(spikesizes) // 2
        l = 0
        u = len(spikesizes) - 1
        # With exactly two candidate spikes the loop body never runs while i
        # is 1, so res0 used to be unbound at the check below; start from
        # None and keep the L1 solution in that case.
        res0 = None
        while u - l > 1:
            s_min = spikesizes[i]
            tmp = c4smin(y - b, s, s_min)
            res = y - b - tmp
            RSS = res.dot(res)
            if RSS < thresh or i == 0:
                l = i
                i = (l + u) // 2
                res0 = tmp
            else:
                u = i
                i = (l + u) // 2
        if i > 0 and res0 is not None:
            c = res0
            s = np.append([0, 0], c[2:] - g[0] * c[1:-1] - g[1] * c[:-2])

    return c, s, b, g, lam
# NOTE(review): the statements up to plt.show() look like the tail of a
# plotting routine whose header lies above the visible fragment -- confirm
# the intended indentation against the full file.
plt.xlim(0, 452)
plt.ylabel(r'$s_{\min}$')
plt.xlabel('Time [s]', labelpad=-10)
plt.show()

# AR(1)
g = .95
sn = .3
# gen_data() is called with its defaults; g and sn above are only used by the solvers below.
Y, trueC, trueSpikes = gen_data()
y = Y[0]
N, T = Y.shape
# Reference solution: only the first two return values are kept.
c, s = constrained_foopsi(y, [g], sn)[:2]
# Thresholded OASIS solution with minimal spike size 0.55.
c_t, s_t = oasisAR1(y, g, s_min=.55)
# For each s_min from 0 upwards in steps of 0.1 (stop value 1.1 excluded),
# the indices whose inferred spike value exceeds 1e-2.
res = [np.where(oasisAR1(y, g, s_min=s0)[1] > 1e-2)[0] for s0 in np.arange(0, 1.1, .1)]
plotTrace(True)

# AR(2)
g = [1.7, -.712]
sn = 1.
Y, trueC, trueSpikes = gen_data(g, sn, seed=3)
# Restrict the ground truth and the analysed trace to frames 150..599.
rng = slice(150, 600)
trueC = trueC[:, rng]
trueSpikes = trueSpikes[:, rng]
y = Y[0, rng]
# NOTE(review): Y itself is not sliced, so T is the full trace length here,
# not len(y).
N, T = Y.shape
# Panel for the current intermediate solution (rectangle in figure fractions).
ax = fig.add_axes([ax1, .31, 1 - ax1, .12])
plot_trace(n)

# do few more iterations
# Alternate three times between updating lam (with sn^2 * len(y) as the last
# argument) and updating g.
for _ in range(3):
    solution, active_set, lam = update_lam(y, solution, active_set, g, lam, sn * sn * len(y))
    solution, active_set, g = update_g(y, active_set, g, lam)
# plot converged results with comparison traces
ax = fig.add_axes([ax1, .07, 1 - ax1, .12])
# Comparison 1: constrained solve with gamma fixed at .95 (labelled "true" in the legend).
sol_given_g = constrained_oasisAR1(y, .95, sn)[0]
# Comparison 2: gamma taken from estimate_parameters, then a solve with penalty 0.
estimated_g = estimate_parameters(y, p=1)[0][0]
print('estimated gamma via autocorrelation: ', estimated_g)
print('optimized gamma : ', g)
sol_PSD_g = oasisAR1(y, estimated_g, 0)[0]
# print((sol_PSD_g-y).dot(sol_PSD_g-y), sn*sn*T # renders constraint problem infeasible
plt.plot(sol_given_g, '--', c=col[6], label=r'true $\gamma$', zorder=11)
plt.plot(sol_PSD_g, c=col[5], label=r'$\gamma$ from autocovariance', zorder=10)
plt.legend(frameon=False, loc=(.1, .62), ncol=2)
plot_trace(n)
# Tick positions 300..1200 are labelled 10..40 to match the 'Time [s]' axis.
plt.xticks([300, 600, 900, 1200], [10, 20, 30, 40])
plt.xlabel('Time [s]', labelpad=-10)
plt.show()

# Pearson correlation of each estimate with the ground-truth calcium trace n.
print('correlation with ground truth calcium for given gamma ',
      np.corrcoef(sol_given_g, trueC[n])[0, 1])
print('correlation with ground truth calcium for estimated gamma ',
      np.corrcoef(sol_PSD_g, trueC[n])[0, 1])
print('correlation with ground truth calcium for optimized gamma ',
      np.corrcoef(solution, trueC[n])[0, 1])
N, T = Y.shape
# One result bucket per solver configuration, keyed by a short option code.
results = {}
for opt in [
        '-', 'l', 'lb', 'lbg', 'lbg10', 'lbg5', 'lbg_ds', 'lbg10_ds', 'lbg5_ds'
]:
    results[opt] = {}
    results[opt]['time'] = []
    results[opt]['distance'] = []
    results[opt]['correlation'] = []
    for i, y in enumerate(Y):
        # Per-trace estimates of the AR(1) coefficient and noise level,
        # a penalty derived from them, and a 15th-percentile baseline.
        g, sn = estimate_parameters(y, p=1, fudge_factor=.99)
        lam = 2 * sn * (1 - g * g)**(-.5)
        b = np.percentile(y, 15)
        # NOTE(review): the lambdas read b, g, sn and lam from the enclosing
        # scope when they are called, not when they are defined.
        if opt == '-':
            foo = lambda y: oasisAR1(y - b, g, lam)
        elif opt == 'l':
            foo = lambda y: constrained_oasisAR1(y - b, g, sn)
        elif opt == 'lb':
            foo = lambda y: constrained_oasisAR1(y, g, sn, optimize_b=True)
        elif opt == 'lbg':
            foo = lambda y: constrained_oasisAR1(
                y, g, sn, optimize_b=True, optimize_g=len(y))
        elif opt == 'lbg10':
            foo = lambda y: constrained_oasisAR1(
                y, g, sn, optimize_b=True, optimize_g=10)
        elif opt == 'lbg5':
            foo = lambda y: constrained_oasisAR1(
                y, g, sn, optimize_b=True, optimize_g=5)
        elif opt == 'lbg_ds':
            # NOTE(review): the fragment ends inside the call below; its
            # arguments and the remaining branches are outside the visible source.
            foo = lambda y: constrained_oasisAR1(
import argparse
import os

import pandas as pd

# Command-line front end around oasisAR1 so that it can be called from R:
# reads a headerless CSV trace, runs the solver and writes a two-column CSV.
parser = argparse.ArgumentParser(description='input to oasis wrapper')
for arg_name, arg_kwargs in (
        ('y_file', {}),
        ('gam', {'type': float}),
        ('theta', {'type': float}),
        ('type', {'type': str}),
        ('out_file', {}),
):
    parser.add_argument(arg_name, **arg_kwargs)
args = parser.parse_args()

# read in tmp data from y_file
print(os.getcwd())
print("trying to open file %s" % args.y_file)
abs_file_path = args.y_file
trace_frame = pd.read_csv(abs_file_path, sep=',', header=None)
y = trace_frame.values.flatten()

# wraps the oasis ar1 function for easy calls from R
# 'penalized' passes theta as the penalty `lam`; any other type passes it as
# the minimal spike size `s_min`.
theta_keyword = 'lam' if args.type == 'penalized' else 's_min'
fit = oasisAR1(y, args.gam, **{theta_keyword: args.theta})

# Stack calcium (fit[0]) and spikes (fit[1]) as rows, then flip to columns.
df = pd.DataFrame([fit[0], fit[1]]).T
df.to_csv(args.out_file, index=False, header=['calcium', 'spikes'])
def cv_select_lambda(y, s_vec, t_vec, lambda_seq, L_min_seq, gamma=0.96):
    """Score every penalty in ``lambda_seq`` by hold-out error and spike-train distance.

    The trace is split into even-indexed (training) and odd-indexed
    (held-out) samples.  For each ``lam`` both solvers are fitted on the
    training half, neighbouring fitted values are averaged to predict the
    held-out samples lying between them, and the mean squared error of that
    prediction is recorded.  Both solvers are then re-run on the full trace
    and the resulting spike trains are compared with the true one using the
    van Rossum and Victor-Purpura distances.

    Parameters
    ----------
    y : sequence of float
        Fluorescence trace; must support slicing.
    s_vec : array
        True spike times; multiplied by the module-level name ``s``
        (presumably a time unit -- TODO confirm) to build the reference
        ``SpikeTrain``.
    t_vec : array
        Time stamps, indexed with the solvers' spike outputs; ``t_vec[-1]``
        is used as ``t_stop`` of every spike train.
    lambda_seq : iterable of float
        Penalty values to evaluate (passed as ``lam`` to ``oasisAR1`` and as
        ``alpha`` to ``dp_ell_0_active_sol``).
    L_min_seq : iterable of float
        Amplitude thresholds applied to the L1 spike estimate.
    gamma : float, optional, default 0.96
        AR(1) decay parameter handed to both solvers.

    Returns
    -------
    dict
        Keys ``'lambda_seq'``, ``'L_min_seq'``, ``'mse_c_l0'``,
        ``'mse_c_l1'``, ``'l_0_vr_list'``, ``'l_1_vr_list'``,
        ``'l_0_vp_list'``, ``'l_1_vp_list'``, ``'l_1_spike_collection'`` and
        ``'l_0_spike_collection'``.  Apart from the two echoed inputs, each
        holds one entry per element of ``lambda_seq``; the ``l_1`` entries
        are themselves lists with one item per threshold in ``L_min_seq``.
    """
    train_y = y[0::2]
    test_y = y[1::2]

    def _holdout_mse(c_train):
        # Midpoints of consecutive training estimates predict the odd samples
        # that sit between them.
        midpoints = [(a + b) / 2 for a, b in zip(c_train[0:-1], c_train[1:])]
        # Compare against exactly as many held-out samples as there are
        # midpoints.  For an even-length trace this equals test_y[0:-1]; for
        # an odd-length trace every held-out sample has a midpoint, and the
        # former fixed [0:-1] slice was one element short, which made the
        # two inputs differ in length.
        return mean_squared_error(midpoints, test_y[:len(midpoints)])

    van_rossum = elephant.spike_train_dissimilarity.van_rossum_dist
    victor_purpura = elephant.spike_train_dissimilarity.victor_purpura_dist
    tau = (2.) * s
    # The reference train does not depend on lambda, so build it once.
    spike_truth = SpikeTrain(s_vec * s, t_stop=t_vec[-1])

    mse_c_l0 = []
    mse_c_l1 = []
    l_0_spike_collection = []
    l_1_spike_collection = []
    l_0_vr_list = []
    l_1_vr_list = []
    l_0_vp_list = []
    l_1_vp_list = []

    for lam in lambda_seq:
        print('lambda = ', lam)

        # cross-validation fit on the even samples only
        c_t_l1, s_t_l1 = oasisAR1(train_y, g=gamma, lam=lam, s_min=0)
        _, s_t_l0, c_t_l0 = dp_ell_0_active_sol(train_y, alpha=lam, gamma=gamma)
        mse_c_l0.append(_holdout_mse(c_t_l0))
        mse_c_l1.append(_holdout_mse(c_t_l1))

        # refit on the complete trace to extract spike times
        c_t_l1, s_t_l1 = oasisAR1(y, g=gamma, lam=lam, s_min=0)
        _, s_t_l0, c_t_l0 = dp_ell_0_active_sol(y, alpha=lam, gamma=gamma)
        l_0_times = t_vec[s_t_l0]
        # one set of L1 spike times per amplitude threshold
        l_1_times = [t_vec[np.where(s_t_l1 > L_min)] for L_min in L_min_seq]
        l_0_spike_collection.append(l_0_times)
        l_1_spike_collection.append(l_1_times)

        l_0_spike = SpikeTrain(l_0_times * s, t_stop=t_vec[-1])
        l_1_spike_list = [SpikeTrain(times * s, t_stop=t_vec[-1])
                          for times in l_1_times]

        # compute post-processing estimators
        # Each call returns a pairwise distance matrix; entry [0, 1] is the
        # distance between the estimate and the truth.
        l_1_vr_list.append([van_rossum([l_1_spike, spike_truth], tau=tau)[0, 1]
                            for l_1_spike in l_1_spike_list])
        l_1_vp_list.append([victor_purpura([l_1_spike, spike_truth])[0, 1]
                            for l_1_spike in l_1_spike_list])
        l_0_vr_list.append(van_rossum([l_0_spike, spike_truth], tau=tau)[0, 1])
        l_0_vp_list.append(victor_purpura([l_0_spike, spike_truth])[0, 1])

    # store results
    return {
        'lambda_seq': lambda_seq,
        'L_min_seq': L_min_seq,
        'mse_c_l0': mse_c_l0,
        'mse_c_l1': mse_c_l1,
        'l_0_vr_list': l_0_vr_list,
        'l_1_vr_list': l_1_vr_list,
        'l_0_vp_list': l_0_vp_list,
        'l_1_vp_list': l_1_vp_list,
        'l_1_spike_collection': l_1_spike_collection,
        'l_0_spike_collection': l_0_spike_collection,
    }